tinygrad/test/test_nn.py
Liam bcf1518309 All devices are equal! (#196)
* Update all devices to be tested

ANE, CPU and OCL all now support all tests.

However, tests are not currently passing on GPU, and I cannot test on CPU.

The failing GPU tests are not an issue caused by this update: they have
not been passing because the required "six" package was not installed.

OpenCL tests have not been run since commit 1a1c63a08b.

Devices now come in three types, handled by a new DeviceTypes enum. (The
goal is to revert to Tensor.<type>, but the current setup allows for
keyword-argument defaults: `device=DeviceTypes.CPU`.)
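
In rough terms, the new device handling looks like the sketch below
(illustrative only; the member values and the exact definition in
tinygrad may differ):

    from enum import Enum

    class DeviceTypes(Enum):
      CPU = 0
      GPU = 1
      ANE = 2

    # keyword-argument defaults now read naturally:
    def randn(*shape, device=DeviceTypes.CPU):
      ...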

All references to Tensor.GPU/CPU/ANE have been converted to the
corresponding `DeviceTypes` enum members.

Refactored the conversion code to allow conversion from any device to
any device.
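
As a rough sketch of the shape of that refactor (the `backends`
registry and the `to_host`/`from_host` hooks below are hypothetical
names, not the actual ones): each backend only needs to move data to
and from host memory, and any source/target pair is routed through it:

    def convert(data, src, dst, backends):
      # hop through host memory so any backend can reach any other
      host = backends[src].to_host(data)    # backend-specific download
      return backends[dst].from_host(host)  # backend-specific upload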

* Add six dependency in requirements.txt

* Resolve failure to run tests

Move six into the GPU required installs and remove it from the
standard installation.

* Remove repeated data conversion

* Refactor method names

Also reduce code with .to and .to_
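
Assuming the usual convention that the trailing underscore suggests
(`.to` returns a copy on the target device, `.to_` moves the tensor in
place), usage looks roughly like this sketch:

    t = Tensor.randn(3, 3)   # defaults to the CPU device
    g = t.to(Device.GPU)     # new tensor on the GPU; t is unchanged
    t.to_(Device.GPU)        # move t itself to the GPU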

* Dynamic device handlers

* Refactor DeviceTypes -> Device

* Add mem copy profiling back

* test_backward_pass_diamond_model passing

* Resolve Sum issue on GPU

* Revert batchnorm2d tests

* Update README with updated API

* ANE testing with

* Last minute line gains
2020-12-15 23:44:08 -08:00

80 lines
2.2 KiB
Python

#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import GPU, Device
from tinygrad.nn import *
from extra.utils import get_parameters
import torch
from .config import ANE

class TestNN(unittest.TestCase):
  device = Device.CPU

  def test_batchnorm2d(self, training=False):
    sz = 4

    # create in tinygrad
    bn = BatchNorm2D(sz, eps=1e-5, training=training, track_running_stats=training)
    bn.weight = Tensor.randn(sz)
    bn.bias = Tensor.randn(sz)
    bn.running_mean = Tensor.randn(sz)
    bn.running_var = Tensor.randn(sz)
    bn.running_var.data[bn.running_var.data < 0] = 0

    # create in torch
    with torch.no_grad():
      tbn = torch.nn.BatchNorm2d(sz).eval()
      tbn.training = training
      tbn.weight[:] = torch.tensor(bn.weight.data)
      tbn.bias[:] = torch.tensor(bn.bias.data)
      tbn.running_mean[:] = torch.tensor(bn.running_mean.data)
      tbn.running_var[:] = torch.tensor(bn.running_var.data)

    np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
    np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

    # trial
    inn = Tensor.randn(2, sz, 3, 3, device=self.device)

    # in tinygrad
    outt = bn(inn)

    # in torch
    toutt = tbn(torch.tensor(inn.cpu().data))

    # close
    np.testing.assert_allclose(outt.data, toutt.detach().numpy(), rtol=5e-5)
    np.testing.assert_allclose(bn.running_mean.data, tbn.running_mean.detach().numpy(), rtol=1e-5)
    # TODO: this is failing
    #np.testing.assert_allclose(bn.running_var.data, tbn.running_var.detach().numpy(), rtol=1e-5)

  def test_batchnorm2d_training(self):
    self.test_batchnorm2d(True)


@unittest.skipUnless(GPU, "Requires GPU")
class TestNNGPU(TestNN):
  device = Device.GPU

  @unittest.skip("Tests not added")
  def test_batchnorm2d(self): pass

  @unittest.skip("Tests not added")
  def test_batchnorm2d_training(self): pass


@unittest.skipUnless(ANE, "Requires ANE")
class TestNNANE(TestNN):
  device = Device.ANE

  @unittest.skip("Tests not added")
  def test_batchnorm2d(self): pass

  @unittest.skip("Tests not added")
  def test_batchnorm2d_training(self): pass


if __name__ == '__main__':
  unittest.main()