tiny gradcheck

This commit is contained in:
0xNaN
2020-10-21 00:19:33 +02:00
parent e8feaa53d6
commit 93bc3c22a0
2 changed files with 118 additions and 1 deletions

View File

@@ -2,6 +2,7 @@ import numpy as np
import torch
import unittest
from tinygrad.tensor import Tensor, Conv2D
from tinygrad.gradcheck import numerical_jacobian, gradcheck
x_init = np.random.randn(1,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
@@ -32,6 +33,47 @@ class TestTinygrad(unittest.TestCase):
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_gradcheck(self):
class TinyModel:
def __init__(self, weights_init):
self.l1 = Tensor(weights_init)
def forward(self, x):
return x.dot(self.l1).relu().logsoftmax()
class TorchModel(torch.nn.Module):
def __init__(self, weights_init):
super(TorchModel, self).__init__()
self.l1 = torch.nn.Linear(*weights_init.shape, bias = False)
self.l1.weight = torch.nn.Parameter(torch.tensor(weights_init.T, requires_grad = True))
def forward(self, x):
return torch.nn.functional.log_softmax(self.l1(x).relu(), dim=1)
layer_weights = np.random.RandomState(1337).random((10, 5))
input_data = np.random.RandomState(7331).random((1, 10)) - 0.5
torch_input = torch.tensor(input_data, requires_grad = True)
torch_model = TorchModel(layer_weights)
torch_out = torch_model(torch_input)
# autograd.grad computes the _sum_ of gradients of given tensors
J_sum = torch.autograd.grad(list(torch_out[0]), torch_input)[0].squeeze().numpy()
tiny_model = TinyModel(layer_weights)
tiny_input = Tensor(input_data)
tiny_out = tiny_model.forward(tiny_input)
NJ = numerical_jacobian(tiny_model, tiny_input)
NJ_sum = NJ.sum(axis = -1)
# checking the numerical approx. of J is close to the one provided autograd
np.testing.assert_allclose(J_sum, NJ_sum, atol = 1e-5)
# test gradcheck
gradcheck_test, _, _ = gradcheck(tiny_model, tiny_input)
self.assertTrue(gradcheck_test)
# coarse approx. since a "big" eps and the non-linearities of the model
gradcheck_test, j, nj = gradcheck(tiny_model, tiny_input, eps = 0.1)
self.assertFalse(gradcheck_test)
def test_conv2d(self):
x = torch.randn((5,2,10,7), requires_grad=True)
w = torch.randn((4,2,3,3), requires_grad=True)
@@ -48,7 +90,7 @@ class TestTinygrad(unittest.TestCase):
np.testing.assert_allclose(w.grad, wt.grad, atol=1e-5)
np.testing.assert_allclose(x.grad, xt.grad, atol=1e-5)
if __name__ == '__main__':
unittest.main()

75
tinygrad/gradcheck.py Normal file
View File

@@ -0,0 +1,75 @@
import numpy as np
from tinygrad.tensor import Tensor
def mask_like(like, mask_inx, mask_value = 1.0):
mask = np.zeros_like(like).reshape(-1)
mask[mask_inx] = mask_value
return mask.reshape(like.shape)
def numerical_jacobian(model, input, eps = 1e-6):
"""
Compute the Jacobian through Finite-Difference Approximation.
Somewhat inspired by [1] but not followed closely.
model : A tinygrad model
input : An input
eps : Perturbation step
returns:
NJ : an approx. of the Jacobian
[1]: https://timvieira.github.io/blog/post/2017/04/21/how-to-test-gradient-implementations/
"""
output = model.forward(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
NJ = np.zeros((ji, jo))
for i in range(ji):
for o in range(jo):
eps_pertub = mask_like(input.data, i, mask_value = eps)
output_perturb_add = model.forward(Tensor(input.data + eps_pertub)).data.reshape(-1)[o]
output_perturb_sub = model.forward(Tensor(input.data - eps_pertub)).data.reshape(-1)[o]
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
NJ[i,o] = grad_approx
return NJ
def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
"""
Checks whether the numerical approx. of the Jacobian of model w.r.t input is close to the
analitical one (computed through .backward())
model : A tinygrad model
input : An input
eps : Perturbation step
atol, rtol: Params for the numpy.allclose test
returns:
test_passed : Bool, whether the test passed
J : Analytical Jacobian
NJ : Finite-Difference approx. Jacobian
"""
NJ = numerical_jacobian(model, input, eps)
output = model.forward(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
J = np.zeros((ji, jo))
for o in range(jo):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar an backpropagate only through it
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.reshape(-1)):
J[i][o] = grad
test_passed = np.allclose(J, NJ, atol=atol, rtol=rtol)
return test_passed, J, NJ