From ee6a73826ba4157e78ba1867eddf0df4c4e69577 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 8 Jan 2024 18:45:03 -0500 Subject: [PATCH] clean up test_nn.py (#3049) used Tensor.train decorator, reordered to always tinygrad instances first, and removed redundant idx cast --- test/test_nn.py | 100 +++++++++++++++++++++++------------------------- 1 file changed, 48 insertions(+), 52 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 1c11989e20..bf7fd9bb5d 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,73 +1,69 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.helpers import CI -from tinygrad.jit import TinyJit -from tinygrad.tensor import Tensor, Device -from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm import torch +from tinygrad import Tensor, Device, TinyJit +from tinygrad.helpers import CI +from tinygrad.nn import BatchNorm2d, Conv1d,ConvTranspose1d, Conv2d,ConvTranspose2d, Linear, GroupNorm, LayerNorm,LayerNorm2d, Embedding, InstanceNorm @unittest.skipIf(CI and Device.DEFAULT == "CUDA", "slow") class TestNN(unittest.TestCase): @unittest.skipIf(Device.DEFAULT == "WEBGPU", "no int64 on WebGPU") def test_sparse_cat_cross_entropy(self): - input = torch.randn(3, 5) - target = torch.empty(3, dtype=torch.long).random_(5) - loss_fun = torch.nn.CrossEntropyLoss(reduction='mean') - loss = loss_fun(input, target) + # create in tinygrad + input = Tensor.randn(3, 5) + target = Tensor.randint((3,), low=0, high=5) + loss = input.sparse_categorical_crossentropy(target) - input_tiny = Tensor(input.detach().numpy()) - target_tiny = Tensor(target.detach().numpy()) - loss_tiny = input_tiny.sparse_categorical_crossentropy(target_tiny) + torch_input = torch.tensor(input.numpy()) + torch_target = torch.tensor(target.numpy(), dtype=torch.long) + torch_loss = torch.nn.CrossEntropyLoss(reduction='mean')(torch_input, torch_target) - 
np.testing.assert_allclose(loss_tiny.numpy(), loss.detach().numpy(), atol=1e-5, rtol=1e-6) + np.testing.assert_allclose(loss.numpy(), torch_loss.detach().numpy(), atol=1e-5, rtol=1e-6) def test_batchnorm2d(self, training=False): - szs = [4, 8, 16, 32] - for sz in szs: - # create in tinygrad - Tensor.training = training - bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training) - bn.weight = Tensor.randn(sz) - bn.bias = Tensor.randn(sz) - bn.running_mean = Tensor.randn(sz) - bn.running_var = Tensor.randn(sz) - bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0 + with Tensor.train(training): + szs = [4, 8, 16, 32] + for sz in szs: + # create in tinygrad + bn = BatchNorm2d(sz, eps=1e-5, track_running_stats=training) + bn.weight = Tensor.randn(sz) + bn.bias = Tensor.randn(sz) + bn.running_mean = Tensor.randn(sz) + bn.running_var = Tensor.randn(sz) + bn.running_var.numpy()[bn.running_var.numpy() < 0] = 0 - # create in torch - with torch.no_grad(): - tbn = torch.nn.BatchNorm2d(sz).eval() - tbn.training = training - tbn.weight[:] = torch.tensor(bn.weight.numpy()) - tbn.bias[:] = torch.tensor(bn.bias.numpy()) - tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy()) - tbn.running_var[:] = torch.tensor(bn.running_var.numpy()) + # create in torch + with torch.no_grad(): + tbn = torch.nn.BatchNorm2d(sz).eval() + tbn.training = training + tbn.weight[:] = torch.tensor(bn.weight.numpy()) + tbn.bias[:] = torch.tensor(bn.bias.numpy()) + tbn.running_mean[:] = torch.tensor(bn.running_mean.numpy()) + tbn.running_var[:] = torch.tensor(bn.running_var.numpy()) - np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) - np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), 
rtol=1e-5, atol=1e-6) - # trial - inn = Tensor.randn(2, sz, 3, 3) + # trial + inn = Tensor.randn(2, sz, 3, 3) - # in tinygrad - outt = bn(inn) + # in tinygrad + outt = bn(inn) - # in torch - toutt = tbn(torch.tensor(inn.numpy())) + # in torch + toutt = tbn(torch.tensor(inn.numpy())) - # close - np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6) - - np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) - - np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6) + # close + np.testing.assert_allclose(outt.numpy(), toutt.detach().numpy(), rtol=5e-4, atol=1e-6) + np.testing.assert_allclose(bn.running_mean.numpy(), tbn.running_mean.detach().numpy(), rtol=1e-5, atol=1e-6) + np.testing.assert_allclose(bn.running_var.numpy(), tbn.running_var.detach().numpy(), rtol=1e-5, atol=1e-6) def test_batchnorm2d_training(self): self.test_batchnorm2d(True) def test_linear(self): - def _test_linear(x): - + def _test_linear(x, in_dim, out_dim): # create in tinygrad model = Linear(in_dim, out_dim) z = model(x) @@ -84,8 +80,8 @@ class TestNN(unittest.TestCase): np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) BS, T, in_dim, out_dim = 4, 2, 8, 16 - _test_linear(Tensor.randn(BS, in_dim)) - _test_linear(Tensor.randn(BS, T, in_dim)) # test with more dims + _test_linear(Tensor.randn(BS, in_dim), in_dim, out_dim) + _test_linear(Tensor.randn(BS, T, in_dim), in_dim, out_dim) # test with more dims def test_conv1d(self): BS, C1, W = 4, 16, 224//4 @@ -316,14 +312,14 @@ class TestNN(unittest.TestCase): # test x = Tensor(np.random.randint(0, vocab_size, (B, T))) z = layer(x) - torch_x = torch.tensor(x.numpy().astype(np.int32)) + torch_x = torch.tensor(x.numpy()) torch_z = torch_layer(torch_x) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) # test with empty input length x = 
Tensor(np.random.randint(0, vocab_size, (B, 0))) z = layer(x) - torch_x = torch.tensor(x.numpy().astype(np.int32)) + torch_x = torch.tensor(x.numpy()) torch_z = torch_layer(torch_x) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8) @@ -333,9 +329,9 @@ class TestNN(unittest.TestCase): return layer(x).realize() for _ in range(3): - x = Tensor(np.random.randint(0, vocab_size, (B, T)).astype(np.float32)) + x = Tensor(np.random.randint(0, vocab_size, (B, T))) z = layer_jit(x) - torch_x = torch.tensor(x.numpy().astype(np.int32)) + torch_x = torch.tensor(x.numpy()) torch_z = torch_layer(torch_x) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)