clean up test_nn.py (#3049)

Used Tensor.train (as a context manager) instead of setting Tensor.training by hand, reordered the tests so the tinygrad instances are always created first, and removed a redundant idx cast.
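The core pattern change: instead of flipping the global Tensor.training flag by hand (and remembering to reset it), the tests now scope it with the Tensor.train context manager, passing the desired mode as its argument. A minimal sketch of the two styles, using only names already imported in this file (Tensor, BatchNorm2d); this is an illustration, not code taken from the diff:

# before: set the global training flag manually and reset it afterwards
Tensor.training = True
out = BatchNorm2d(4)(Tensor.randn(2, 4, 3, 3))
Tensor.training = False

# after: Tensor.train(True) enables training mode for the duration of the block only
with Tensor.train(True):
  out = BatchNorm2d(4)(Tensor.randn(2, 4, 3, 3))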
chenyu
2024-01-08 18:45:03 -05:00
committed by GitHub
parent 3eb3664074
commit ee6a73826b


@@ -1,73 +1,69 @@
#!/usr/bin/env python
import unittest
import numpy as np
-from tinygrad.helpers import CI
-from tinygrad.jit import TinyJit
-from tinygrad.tensor import Tensor, Device
-from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
import torch
+from tinygrad import Tensor, Device, TinyJit
+from tinygrad.helpers import CI
+from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm
@unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow")
class TestNN(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU")
def test_sparse_cat_cross_entropy(self):
input = torch.randn(3, 5)
target = torch.empty(3, dtype=torch.long).random_(5)
loss_fun = torch.nn.CrossEntropyLoss(reduction='mean')
loss = loss_fun(input, target)
# create in tinygrad
input = Tensor.randn(3, 5)
target = Tensor.randint((3,), low=0, high=4)
loss = input.sparse_categorical_crossentropy(target)
input_tiny = Tensor(input.detach().numpy())
target_tiny = Tensor(target.detach().numpy())
loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny)
torch_input = torch.tensor(input.numpy())
torch_target = torch.tensor(target.numpy(), dtype=torch.long)
torch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(torch_input, torch_target)
np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6)
np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6)
  def test_batchnorm2d(self, training=False):
-    szs = [4, 8, 16, 32]
-    for sz in szs:
-      # create in tinygrad
-      Tensor.training = training
-      bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
-      bn.weight = Tensor.randn(sz)
-      bn.bias = Tensor.randn(sz)
-      bn.running_mean = Tensor.randn(sz)
-      bn.running_var = Tensor.randn(sz)
-      bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0
+    with Tensor.train(training):
+      szs = [4, 8, 16, 32]
+      for sz in szs:
+        # create in tinygrad
+        bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training)
+        bn.weight = Tensor.randn(sz)
+        bn.bias = Tensor.randn(sz)
+        bn.running_mean = Tensor.randn(sz)
+        bn.running_var = Tensor.randn(sz)
+        bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0
-      # create in torch
-      with torch.no_grad():
-        tbn = torch.nn.BatchNorm2d(sz).eval()
-        tbn.training = training
-        tbn.weight[:] = torch.tensor(bn.weight.numpy())
-        tbn.bias[:] = torch.tensor(bn.bias.numpy())
-        tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
-        tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
+        # create in torch
+        with torch.no_grad():
+          tbn = torch.nn.BatchNorm2d(sz).eval()
+          tbn.training = training
+          tbn.weight[:] = torch.tensor(bn.weight.numpy())
+          tbn.bias[:] = torch.tensor(bn.bias.numpy())
+          tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy())
+          tbn.running_var[:] = torch.tensor(bn.running_var.numpy())
-      np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
-      np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
+        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
+        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
-      # trial
-      inn = Tensor.randn(2, sz, 3, 3)
+        # trial
+        inn = Tensor.randn(2, sz, 3, 3)
-      # in tinygrad
-      outt = bn(inn)
+        # in tinygrad
+        outt = bn(inn)
-      # in torch
-      toutt = tbn(torch.tensor(inn.numpy()))
+        # in torch
+        toutt = tbn(torch.tensor(inn.numpy()))
-      # close
-      np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)
-      np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
-      np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
+        # close
+        np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6)
+        np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6)
+        np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6)
  def test_batchnorm2d_training(self):
    self.test_batchnorm2d(True)
  def test_linear(self):
-    def _test_linear(x):
+    def _test_linear(x, in_dim, out_dim):
      # create in tinygrad
      model = Linear(in_dim, out_dim)
      z = model(x)
@@ -84,8 +80,8 @@ class TestNN(unittest.TestCase):
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
    BS, T, in_dim, out_dim = 4, 2, 8, 16
-    _test_linear(Tensor.randn(BS, in_dim))
-    _test_linear(Tensor.randn(BS, T, in_dim)) # test with more dims
+    _test_linear(Tensor.randn(BS, in_dim), in_dim, out_dim)
+    _test_linear(Tensor.randn(BS, T, in_dim), in_dim, out_dim) # test with more dims
  def test_conv1d(self):
    BS, C1, W = 4, 16, 224//4
@@ -316,14 +312,14 @@ class TestNN(unittest.TestCase):
    # test
    x = Tensor(np.random.randint(0, vocab_size, (B, T)))
    z = layer(x)
-    torch_x = torch.tensor(x.numpy().astype(np.int32))
+    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
    # test with empty input length
    x = Tensor(np.random.randint(0, vocab_size, (B, 0)))
    z = layer(x)
-    torch_x = torch.tensor(x.numpy().astype(np.int32))
+    torch_x = torch.tensor(x.numpy())
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)
@@ -333,9 +329,9 @@ class TestNN(unittest.TestCase):
      return layer(x).realize()
    for _ in range(3):
-      x = Tensor(np.random.randint(0, vocab_size, (B, T)).astype(np.float32))
+      x = Tensor(np.random.randint(0, vocab_size, (B, T)))
      z = layer_jit(x)
-      torch_x = torch.tensor(x.numpy().astype(np.int32))
+      torch_x = torch.tensor(x.numpy())
      torch_z = torch_layer(torch_x)
      np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)