mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
* models matrix * fix typo and install gpu deps * install llvm deps if needed * fix * testops with cuda * remove pip cache since not work * cuda env * install cuda deps * maybe it will work now * i can't read * all tests in matrix * trim down more * opencl stuff in matrix * opencl pip cache * test split * change cuda test exclusion * test * fix cuda maybe * add models * add more n=auto * third thing * fix bug * cache pip more * change name * update tests * try again cause why not * balance * try again... * try apt cache for cuda * try on gpu: * try cuda again * update packages step * replace libz-dev with zlib1g-dev * only cache cuda * why error * fix gpuocelot bug * apt cache err * apt cache too slow? * opt and image in single runner * add a couple n=autos * remove test matrix * try cuda apt cache again * libz-dev -> zlib1g-dev * remove -s since not supported by xdist * the cache takes too long and doesn't work * combine webgpu and metal tests * combine imagenet to c and cpu tests * torch tests with linters * torch back by itself * small windows clang test with torch tests * fix a goofy windows bug * im dumb * bro * clang with linters * fix pylint error * linter not work on windows * try with clang again * clang and imagenet? * install deps * fix * fix quote * clang by itself (windows too slow) * env vars for imagenet * cache pip for metal and webgpu tests * try torch with metal and webgpu * doesn't work, too long * remove -v * try -n=logical * don't use logical * revert accidental thing * remove some prints unless CI * fix print unless CI * ignore speed tests for slow tests * clang windows in matrix (ubuntu being tested in imagenet->c test) * try manual pip cache * fix windows pip cache path * all manual pip cache * fix pip cache dir for macos * print_ci function in helpers * CI as variable, no print_ci * missed one * cuda tests with docker image * remove setup-python action for cuda * python->python3? 
* remove -s -v * try fix pip cache * maybe fix * try to fix pip cache * is this the path? * maybe cache pip * try again * create wheels dir * ? * cuda pip deps in dockerfile * disable pip cache for clang * image from ghcr instead of docker hub * why is clang like this * fast deps * try use different caches * remove the fast thing * try with lighter image * remove setup python for cuda * small docker and cuda fast deps * ignore a few more tests * cool docker thing (maybe) * oops * quotes * fix docker command * fix bug * ignore train efficientnet test * remove dockerfile (docker stuff takes too long) * remove docker stuff and normal cuda * oops * ignore the tests for cuda * does this work * ignore test_train on slow backends * add space * llvm ignore same tests as cuda * nvm * ignore lr scheduler tests * get some stats * fix ignore bug * remove extra ' * remove and * ignore test for llvm * change ignored tests and duration on all backends * fix * and -> or * ignore some more cuda tests * finally? * does this fix it * remove durations=0 * add some more tests to llvm * make last pytest more readable * fix * don't train efficientnet on cpu * try w/out pip cache * pip cache seems to be generally better * pytest file markers * try apt fast for cuda * use quick install for apt-fast * apt-fast not worth * apt-get to apt * fix typo * suppress warnings * register markers * disable debug on fuzz tests * change marker names * apt update and apt install in one command * update marker names in test.yml * webgpu pytest marker
291 lines
10 KiB
Python
Executable File
291 lines
10 KiB
Python
Executable File
#!/usr/bin/env python
|
|
import unittest
|
|
import numpy as np
|
|
from extra.utils import WINDOWS
|
|
from tinygrad.helpers import getenv
|
|
from tinygrad.jit import TinyJit
|
|
from tinygrad.tensor import Tensor, Device
|
|
from tinygrad.nn import BatchNorm2d, Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, GroupNorm, LayerNorm, LayerNorm2d, Embedding, InstanceNorm
|
|
import torch
|
|
import pytest
|
|
|
|
# pytest applies every mark in a module-level `pytestmark` list to all tests in
# this file. The marker names are presumably registered in the project's pytest
# configuration (not visible here) and used by CI to select/deselect this file.
pytestmark = [
  pytest.mark.exclude_cuda,
  pytest.mark.webgpu,
]

class TestNN(unittest.TestCase):
  """Parity tests: each tinygrad.nn layer is compared against its torch.nn counterpart.

  Every test builds the tinygrad layer, copies its parameters into a freshly
  constructed torch layer, feeds both the same input and asserts that the
  outputs agree within the stated tolerances.
  """

  def test_batchnorm2d(self, training=False):
    """Compare BatchNorm2d outputs and running statistics with torch.

    Args:
      training: value used for the global `Tensor.training`, for the torch
        module's train/eval mode, and for tinygrad's `track_running_stats`.
    """
    # Tensor.training is a class-level (global) flag. Restore it on exit so this
    # test -- and test_batchnorm2d_training in particular -- does not leave
    # training mode switched on for every test that runs afterwards.
    prev_training = Tensor.training
    Tensor.training = training
    try:
      for sz in [4, 8, 16, 32]:
        # create in tinygrad
        bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
        bn.weight = Tensor.randn(sz)
        bn.bias = Tensor.randn(sz)
        bn.running_mean = Tensor.randn(sz)
        # A variance must be non-negative, so clamp it in numpy *before* building
        # the Tensor. Writing into `bn.running_var.numpy()` only has an effect if
        # numpy() aliases the tensor's storage; when it returns a copy, the clamp
        # is silently lost.
        bn.running_var = Tensor(np.maximum(np.random.randn(sz), 0).astype(np.float32))

        # create in torch with identical parameters and running statistics
        with torch.no_grad():
          tbn = torch.nn.BatchNorm2d(sz).train(training)
          tbn.weight[:] = torch.tensor(bn.weight.numpy())
          tbn.bias[:] = torch.tensor(bn.bias.numpy())
          tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
          tbn.running_var[:] = torch.tensor(bn.running_var.numpy())

        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)

        # trial: same input through both layers
        inn = Tensor.randn(2, sz, 3, 3)
        outt = bn(inn)
        toutt = tbn(torch.tensor(inn.cpu().numpy()))
        np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)

        # the running statistics must still agree after the forward pass
        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
    finally:
      Tensor.training = prev_training

  def test_batchnorm2d_training(self):
    """Run the BatchNorm2d parity check with training enabled."""
    self.test_batchnorm2d(training=True)

  def test_linear(self):
    """Compare Linear against torch for 2-D and 3-D inputs."""
    BS, T, in_dim, out_dim = 4, 2, 8, 16

    def _test_linear(x):
      # create in tinygrad
      model = Linear(in_dim, out_dim)
      z = model(x)

      # create in torch with the same weights and evaluate on the same input
      with torch.no_grad():
        torch_layer = torch.nn.Linear(in_dim, out_dim).eval()
        torch_layer.weight[:] = torch.tensor(model.weight.numpy(), dtype=torch.float32)
        torch_layer.bias[:] = torch.tensor(model.bias.numpy(), dtype=torch.float32)
        torch_x = torch.tensor(x.cpu().numpy(), dtype=torch.float32)
        torch_z = torch_layer(torch_x)

      # test
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

    _test_linear(Tensor.randn(BS, in_dim))
    _test_linear(Tensor.randn(BS, T, in_dim))  # test with more dims

  def test_conv1d(self):
    """Compare Conv1d (stride 2, padding 1) against torch."""
    BS, C1, W = 4, 16, 224
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = Conv1d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_conv2d(self):
    """Compare Conv2d (stride 2, padding 1) against torch."""
    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.Conv2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
  def test_conv_transpose1d(self):
    """Compare ConvTranspose1d (stride 2, padding 1) against torch."""
    BS, C1, W = 4, 16, 224
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose1d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
  def test_conv_transpose2d(self):
    """Compare ConvTranspose2d (stride 2, padding 1) against torch."""
    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1

    # create in tinygrad
    layer = ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.ConvTranspose2d(C1, C2, kernel_size=K, stride=S, padding=P).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.uniform(BS, C1, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  def test_groupnorm(self):
    """Compare GroupNorm (3 groups over 6 channels) against torch."""
    BS, H, W, C, G = 20, 10, 10, 6, 3

    # create in tinygrad
    layer = GroupNorm(G, C)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.GroupNorm(G, C).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.randn(BS, C, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

  def test_layernorm(self):
    """Compare LayerNorm over the trailing [H, W] dims against torch."""
    N, C, H, W = 20, 5, 10, 10

    # create in tinygrad
    layer = LayerNorm([H, W])

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.LayerNorm([H, W]).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.randn(N, C, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

  def test_layernorm_2d(self):
    """Compare LayerNorm2d with torch LayerNorm([C]) applied to channels-last data."""
    N, C, H, W = 20, 5, 10, 10

    # create in tinygrad
    layer = LayerNorm2d(C)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.LayerNorm([C]).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test: torch normalizes the last dim, so move C last (NCHW -> NHWC) and back
    x = Tensor.randn(N, C, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    # this comparison was previously missing, so the test could never fail
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

  def test_instancenorm_2d(self):
    """Compare InstanceNorm on a 4-D input with torch InstanceNorm2d(affine=True)."""
    N, C, H, W = 20, 5, 10, 10

    # create in tinygrad
    layer = InstanceNorm(C)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.InstanceNorm2d(C, affine=True).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.randn(N, C, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

  def test_instancenorm_3d(self):
    """Compare InstanceNorm on a 5-D input with torch InstanceNorm3d(affine=True)."""
    N, C, D, H, W = 20, 5, 3, 10, 10

    # create in tinygrad
    layer = InstanceNorm(C)

    # create in torch
    with torch.no_grad():
      torch_layer = torch.nn.InstanceNorm3d(C, affine=True).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)
      torch_layer.bias[:] = torch.tensor(layer.bias.numpy(), dtype=torch.float32)

    # test
    x = Tensor.randn(N, C, D, H, W)
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-3, rtol=5e-3)

  def test_embedding(self):
    """Compare Embedding lookups against torch, both eagerly and under TinyJit."""
    B, T, C, VS = 4, 10, 20, 28

    # create in tinygrad
    layer = Embedding(VS, C)

    with torch.no_grad():
      torch_layer = torch.nn.Embedding(VS, C).eval()
      torch_layer.weight[:] = torch.tensor(layer.weight.numpy(), dtype=torch.float32)

    # test: tinygrad receives the indices as float32, torch as int32
    x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
    z = layer(x)
    torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

    # test with jit enabled
    @TinyJit
    def layer_jit(x):
      return layer(x).realize()

    # several calls so the jitted path is exercised, not only the first trace
    for _ in range(3):
      x = Tensor(np.random.randint(0, VS, (B, T)).astype(np.float32))
      z = layer_jit(x)
      torch_x = torch.tensor(x.cpu().numpy().astype(np.int32))
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

if __name__ == '__main__':
  # Allow running this file directly (python test_nn.py) instead of via pytest.
  unittest.main()