mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
* Consistent GPU classes Convert the existing GPU classes into one standard format. Remove duplicated functions in `test_mnist` and create a TestMNISTGPU class. This reduces line count and ensures consistency. Use `@unittest.skipUnless(GPU, "Requires GPU")` instead of `if GPU:` to skip GPU testing. This will ensure that skipped tests are displayed accordingly in the pytest output. * Optim Testing now supports GPU * Tensor testing now supports GPU jacobian and gradcheck auto skipped until GPU float64 support added. * GPU support for custom constructor methods * Remove GPU flag from Model constructors It was requested that the `gpu` kwarg be removed from the model constructor. GPU conversion is now handled in the train function. This also required the conversion of Optimizer parameters as they are constructed prior to execution of the `train` function and are dependant on the model GPU state. * Fix typo: float32->float64 * Clean `get_parameters` utility Just a quick refactor w/ the new support for optimizers. * Remove GPU kwarg from TinyNet Remove `gpu` kwarg from tiny net to match test_mnist `train` function.
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
import numpy as np
|
|
import torch
|
|
import unittest
|
|
from tinygrad.tensor import Tensor, GPU
|
|
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
|
|
|
x_init = np.random.randn(1,3).astype(np.float32)
|
|
W_init = np.random.randn(3,3).astype(np.float32)
|
|
m_init = np.random.randn(1,3).astype(np.float32)
|
|
|
|
class TestTinygrad(unittest.TestCase):
|
|
gpu = False
|
|
|
|
def test_backward_pass(self):
|
|
def test_tinygrad():
|
|
x = Tensor(x_init, gpu=self.gpu)
|
|
W = Tensor(W_init, gpu=self.gpu)
|
|
m = Tensor(m_init, gpu=self.gpu)
|
|
out = x.dot(W).relu()
|
|
out = out.logsoftmax()
|
|
out = out.mul(m).add(m).sum()
|
|
out.backward()
|
|
return out.cpu().data, x.grad.cpu().data, W.grad.cpu().data
|
|
|
|
def test_pytorch():
|
|
x = torch.tensor(x_init, requires_grad=True)
|
|
W = torch.tensor(W_init, requires_grad=True)
|
|
m = torch.tensor(m_init)
|
|
out = x.matmul(W).relu()
|
|
out = torch.nn.functional.log_softmax(out, dim=1)
|
|
out = out.mul(m).add(m).sum()
|
|
out.backward()
|
|
return out.detach().numpy(), x.grad, W.grad
|
|
|
|
for x,y in zip(test_tinygrad(), test_pytorch()):
|
|
np.testing.assert_allclose(x, y, atol=1e-5)
|
|
|
|
def test_jacobian(self):
|
|
W = np.random.RandomState(1337).random((10, 5))
|
|
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
|
|
|
torch_x = torch.tensor(x, requires_grad=True)
|
|
torch_W = torch.tensor(W, requires_grad=True)
|
|
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
|
|
PJ = torch.autograd.functional.jacobian(torch_func, torch_x).squeeze().numpy()
|
|
|
|
tiny_x = Tensor(x, gpu=self.gpu)
|
|
tiny_W = Tensor(W, gpu=self.gpu)
|
|
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
|
J = jacobian(tiny_func, tiny_x)
|
|
NJ = numerical_jacobian(tiny_func, tiny_x)
|
|
|
|
np.testing.assert_allclose(PJ, J, atol = 1e-5)
|
|
np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
|
|
|
|
def test_gradcheck(self):
|
|
W = np.random.RandomState(1337).random((10, 5))
|
|
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
|
|
|
tiny_x = Tensor(x, gpu=self.gpu)
|
|
tiny_W = Tensor(W, gpu=self.gpu)
|
|
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
|
|
|
self.assertTrue(gradcheck(tiny_func, tiny_x))
|
|
|
|
# coarse approx. since a "big" eps and the non-linearities of the model
|
|
self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
|
|
|
|
|
|
@unittest.skipUnless(GPU, "Requires GPU")
|
|
class TestTinygradGPU(TestTinygrad):
|
|
gpu = True
|
|
|
|
@unittest.skip("float64 not supported on GPU")
|
|
def test_jacobian(self): pass
|
|
|
|
@unittest.skip("float64 not supported on GPU")
|
|
def test_gradcheck(self): pass
|
|
|
|
|
|
if __name__ == '__main__':
|
|
unittest.main()
|