Files
tinygrad/test/test_tensor.py
George Hotz 9ac1ad40d6 Add GPU Support! (do not merge yet) (#41)
* copy tensors to and from gpu

* add on GPU

* adding works

* we stick shapes in

* works on cpu and gpu

* test changes, not passing yet

* something else

* op tests pass

* add, mean, and sum have working forward/backward

* mul ops test

* no gpu support, no problem

* test pass, clean up later

* gpu cleanup

* cleanup test ops, don't let div fail

* revert more

* simpler dispatcher

* clean up grad

* GPU and

* grad is a Tensor now

* gate test on GPU

* cleanups

* late loading gpu

* GPU as input option

* last cleanups
2020-11-01 07:00:49 -08:00

69 lines
2.2 KiB
Python

import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor
from tinygrad.gradcheck import numerical_jacobian, jacobian, gradcheck
# Shared inputs for TestTinygrad.test_backward_pass: an input row, a weight
# matrix and an elementwise multiplier/offset, all float32.
# They are drawn from a dedicated, seeded generator (instead of the global,
# unseeded NumPy RNG) so that every run sees the same values and a failing
# comparison can be reproduced; this matches the seeded RandomState usage in
# the other tests of this file.
_fixture_rng = np.random.RandomState(1337)
x_init = _fixture_rng.randn(1, 3).astype(np.float32)
W_init = _fixture_rng.randn(3, 3).astype(np.float32)
m_init = _fixture_rng.randn(1, 3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
    """Compares tinygrad Tensor results with PyTorch and the gradcheck helpers."""

    def test_backward_pass(self):
        """Run the same small graph in tinygrad and PyTorch and compare.

        The graph is dot/matmul -> relu -> log-softmax -> mul(m) -> add(m)
        -> sum. The scalar output and the gradients w.r.t. the input and the
        weights must match within an absolute tolerance of 1e-5.
        """

        def run_tinygrad():
            inp, weight, mask = Tensor(x_init), Tensor(W_init), Tensor(m_init)
            log_probs = inp.dot(weight).relu().logsoftmax()
            loss = log_probs.mul(mask).add(mask).sum()
            loss.backward()
            return loss.data, inp.grad.data, weight.grad.data

        def run_pytorch():
            inp = torch.tensor(x_init, requires_grad=True)
            weight = torch.tensor(W_init, requires_grad=True)
            mask = torch.tensor(m_init)
            activated = inp.matmul(weight).relu()
            log_probs = torch.nn.functional.log_softmax(activated, dim=1)
            loss = log_probs.mul(mask).add(mask).sum()
            loss.backward()
            return loss.detach().numpy(), inp.grad, weight.grad

        # Pairs are (output, input gradient, weight gradient) in that order.
        for got, expected in zip(run_tinygrad(), run_pytorch()):
            np.testing.assert_allclose(got, expected, atol=1e-5)

    def test_jacobian(self):
        """Check `jacobian` and `numerical_jacobian` against PyTorch's Jacobian."""
        weights = np.random.RandomState(1337).random((10, 5))
        inputs = np.random.RandomState(7331).random((1, 10)) - 0.5

        torch_x = torch.tensor(inputs, requires_grad=True)
        torch_W = torch.tensor(weights, requires_grad=True)

        def torch_model(t):
            return torch.nn.functional.log_softmax(t.matmul(torch_W).relu(), dim=1)

        reference = torch.autograd.functional.jacobian(torch_model, torch_x)
        reference = reference.squeeze().numpy()

        tiny_x = Tensor(inputs)
        tiny_W = Tensor(weights)

        def tiny_model(t):
            return t.dot(tiny_W).relu().logsoftmax()

        analytic = jacobian(tiny_model, tiny_x)
        numeric = numerical_jacobian(tiny_model, tiny_x)

        np.testing.assert_allclose(reference, analytic, atol=1e-5)
        np.testing.assert_allclose(reference, numeric, atol=1e-5)

    def test_gradcheck(self):
        """`gradcheck` passes with its default eps and fails with eps = 0.1."""
        weights = np.random.RandomState(1337).random((10, 5))
        inputs = np.random.RandomState(7331).random((1, 10)) - 0.5

        tiny_x = Tensor(inputs)
        tiny_W = Tensor(weights)

        def tiny_model(t):
            return t.dot(tiny_W).relu().logsoftmax()

        passes_with_default_eps = gradcheck(tiny_model, tiny_x)
        self.assertTrue(passes_with_default_eps)

        # coarse approx. since a "big" eps and the non-linearities of the model
        passes_with_big_eps = gradcheck(tiny_model, tiny_x, eps=0.1)
        self.assertFalse(passes_with_big_eps)
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()