Consistent testing (#137)

* Consistent GPU classes

Convert the existing GPU classes into one standard format.

Remove duplicated functions in `test_mnist` and create a TestMNISTGPU
class. This reduces line count and ensures consistency.

Use `@unittest.skipUnless(GPU, "Requires GPU")` instead of `if GPU:` to
skip GPU testing. This will ensure that skipped tests are displayed
accordingly in the pytest output.

* Optim Testing now supports GPU

* Tensor testing now supports GPU

jacobian and gradcheck auto skipped until GPU float64 support added.

* GPU support for custom constructor methods

* Remove GPU flag from Model constructors

It was requested that the `gpu` kwarg be removed from the model
constructor. GPU conversion is now handled in the train function.

This also required the conversion of Optimizer parameters as they are
constructed prior to execution of the `train` function and are dependant
on the model GPU state.

* Fix typo: float32->float64

* Clean `get_parameters` utility

Just a quick refactor w/ the new support for optimizers.

* Remove GPU kwarg from TinyNet

Remove `gpu` kwarg from tiny net to match test_mnist `train` function.
This commit is contained in:
Liam
2020-12-09 11:25:27 +01:00
committed by GitHub
parent 34b38dd4d0
commit 89d0ff6989
5 changed files with 64 additions and 63 deletions

View File

@@ -1,7 +1,7 @@
import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
from tinygrad.tensor import Tensor, GPU
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)
@@ -9,16 +9,18 @@ W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
gpu = False
def test_backward_pass(self):
def test_tinygrad():
x = Tensor(x_init)
W = Tensor(W_init)
m = Tensor(m_init)
x = Tensor(x_init, gpu=self.gpu)
W = Tensor(W_init, gpu=self.gpu)
m = Tensor(m_init, gpu=self.gpu)
out = x.dot(W).relu()
out = out.logsoftmax()
out = out.mul(m).add(m).sum()
out.backward()
return out.data, x.grad.data, W.grad.data
return out.cpu().data, x.grad.cpu().data, W.grad.cpu().data
def test_pytorch():
x = torch.tensor(x_init, requires_grad=True)
@@ -42,8 +44,8 @@ class TestTinygrad(unittest.TestCase):
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_x = Tensor(x, gpu=self.gpu)
tiny_W = Tensor(W, gpu=self.gpu)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
J = jacobian(tiny_func, tiny_x)
NJ = numerical_jacobian(tiny_func, tiny_x)
@@ -55,8 +57,8 @@ class TestTinygrad(unittest.TestCase):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_x = Tensor(x, gpu=self.gpu)
tiny_W = Tensor(W, gpu=self.gpu)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
self.assertTrue(gradcheck(tiny_func, tiny_x))
@@ -64,5 +66,17 @@ class TestTinygrad(unittest.TestCase):
# coarse approx. since a "big" eps and the non-linearities of the model
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
@unittest.skipUnless(GPU, "Requires GPU")
class TestTinygradGPU(TestTinygrad):
gpu = True
@unittest.skip("float64 not supported on GPU")
def test_jacobian(self): pass
@unittest.skip("float64 not supported on GPU")
def test_gradcheck(self): pass
if __name__ == '__main__':
unittest.main()