extracting jacobian and test_jacobian

This commit is contained in:
0xNaN
2020-10-21 22:08:35 +02:00
parent 93bc3c22a0
commit 1561d3b9c0
2 changed files with 55 additions and 31 deletions

View File

@@ -33,6 +33,26 @@ class TestTinygrad(unittest.TestCase):
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
def test_jacobian(self):
W = np.random.RandomState(1337).random((10, 5))
x = np.random.RandomState(7331).random((1, 10)) - 0.5
torch_x = torch.tensor(x, requires_grad=True)
torch_W = torch.tensor(W, requires_grad=True)
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
torch_out = torch_func(torch_x)
# autograd.grad computes the _sum_ of gradients of given tensors
J_sum = torch.autograd.grad(list(torch_out[0]), torch_x)[0].squeeze().numpy()
tiny_x = Tensor(x)
tiny_W = Tensor(W)
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
NJ = numerical_jacobian(tiny_func, tiny_x)
NJ_sum = NJ.sum(axis = -1)
np.testing.assert_allclose(J_sum, NJ_sum, atol = 1e-5)
def test_gradcheck(self):
class TinyModel:
def __init__(self, weights_init):
@@ -53,25 +73,16 @@ class TestTinygrad(unittest.TestCase):
torch_input = torch.tensor(input_data, requires_grad = True)
torch_model = TorchModel(layer_weights)
torch_out = torch_model(torch_input)
# autograd.grad computes the _sum_ of gradients of given tensors
J_sum = torch.autograd.grad(list(torch_out[0]), torch_input)[0].squeeze().numpy()
tiny_model = TinyModel(layer_weights)
tiny_input = Tensor(input_data)
tiny_out = tiny_model.forward(tiny_input)
NJ = numerical_jacobian(tiny_model, tiny_input)
NJ_sum = NJ.sum(axis = -1)
# checking the numerical approx. of J is close to the one provided autograd
np.testing.assert_allclose(J_sum, NJ_sum, atol = 1e-5)
# test gradcheck
gradcheck_test, _, _ = gradcheck(tiny_model, tiny_input)
gradcheck_test, _, _ = gradcheck(tiny_model.forward, tiny_input)
self.assertTrue(gradcheck_test)
# coarse approx. since a "big" eps and the non-linearities of the model
gradcheck_test, j, nj = gradcheck(tiny_model, tiny_input, eps = 0.1)
gradcheck_test, j, nj = gradcheck(tiny_model.forward, tiny_input, eps = 0.1)
self.assertFalse(gradcheck_test)
def test_conv2d(self):

View File

@@ -1,6 +1,33 @@
import numpy as np
from tinygrad.tensor import Tensor
def jacobian(model, input):
"""
Compute the (analytical) Jacobian of model w.r.t. input.
model : A tinygrad model
input : An input
returns:
J : Jacobian
"""
output = model(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
J = np.zeros((ji, jo))
for o in range(jo):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar an backpropagate only through it
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.reshape(-1)):
J[i][o] = grad
return J
def mask_like(like, mask_inx, mask_value = 1.0):
mask = np.zeros_like(like).reshape(-1)
mask[mask_inx] = mask_value
@@ -21,7 +48,7 @@ def numerical_jacobian(model, input, eps = 1e-6):
[1]: https://timvieira.github.io/blog/post/2017/04/21/how-to-test-gradient-implementations/
"""
output = model.forward(input)
output = model(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
@@ -30,9 +57,9 @@ def numerical_jacobian(model, input, eps = 1e-6):
for i in range(ji):
for o in range(jo):
eps_pertub = mask_like(input.data, i, mask_value = eps)
output_perturb_add = model.forward(Tensor(input.data + eps_pertub)).data.reshape(-1)[o]
output_perturb_sub = model.forward(Tensor(input.data - eps_pertub)).data.reshape(-1)[o]
eps_perturb = mask_like(input.data, i, mask_value = eps)
output_perturb_add = model(Tensor(input.data + eps_perturb)).data.reshape(-1)[o]
output_perturb_sub = model(Tensor(input.data - eps_perturb)).data.reshape(-1)[o]
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
@@ -42,7 +69,7 @@ def numerical_jacobian(model, input, eps = 1e-6):
def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
"""
Checks whether the numerical approx. of the Jacobian of model w.r.t input is close to the
analitical one (computed through .backward())
analytical one.
model : A tinygrad model
input : An input
@@ -55,21 +82,7 @@ def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
NJ : Finite-Difference approx. Jacobian
"""
NJ = numerical_jacobian(model, input, eps)
output = model.forward(input)
ji = input.data.reshape(-1).shape[-1]
jo = output.data.reshape(-1).shape[-1]
J = np.zeros((ji, jo))
for o in range(jo):
# tinygrad doesn't support slicing, tiny-hack to select
# the needed scalar an backpropagate only through it
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
o_scalar.backward()
for i, grad in enumerate(input.grad.reshape(-1)):
J[i][o] = grad
J = jacobian(model, input)
test_passed = np.allclose(J, NJ, atol=atol, rtol=rtol)
return test_passed, J, NJ