mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
extracting jacobian and test_jacobian
This commit is contained in:
@@ -33,6 +33,26 @@ class TestTinygrad(unittest.TestCase):
|
||||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
def test_jacobian(self):
|
||||
W = np.random.RandomState(1337).random((10, 5))
|
||||
x = np.random.RandomState(7331).random((1, 10)) - 0.5
|
||||
|
||||
torch_x = torch.tensor(x, requires_grad=True)
|
||||
torch_W = torch.tensor(W, requires_grad=True)
|
||||
torch_func = lambda x: torch.nn.functional.log_softmax(x.matmul(torch_W).relu(), dim=1)
|
||||
torch_out = torch_func(torch_x)
|
||||
|
||||
# autograd.grad computes the _sum_ of gradients of given tensors
|
||||
J_sum = torch.autograd.grad(list(torch_out[0]), torch_x)[0].squeeze().numpy()
|
||||
|
||||
tiny_x = Tensor(x)
|
||||
tiny_W = Tensor(W)
|
||||
tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
|
||||
NJ = numerical_jacobian(tiny_func, tiny_x)
|
||||
NJ_sum = NJ.sum(axis = -1)
|
||||
|
||||
np.testing.assert_allclose(J_sum, NJ_sum, atol = 1e-5)
|
||||
|
||||
def test_gradcheck(self):
|
||||
class TinyModel:
|
||||
def __init__(self, weights_init):
|
||||
@@ -53,25 +73,16 @@ class TestTinygrad(unittest.TestCase):
|
||||
|
||||
torch_input = torch.tensor(input_data, requires_grad = True)
|
||||
torch_model = TorchModel(layer_weights)
|
||||
torch_out = torch_model(torch_input)
|
||||
# autograd.grad computes the _sum_ of gradients of given tensors
|
||||
J_sum = torch.autograd.grad(list(torch_out[0]), torch_input)[0].squeeze().numpy()
|
||||
|
||||
tiny_model = TinyModel(layer_weights)
|
||||
tiny_input = Tensor(input_data)
|
||||
tiny_out = tiny_model.forward(tiny_input)
|
||||
NJ = numerical_jacobian(tiny_model, tiny_input)
|
||||
NJ_sum = NJ.sum(axis = -1)
|
||||
|
||||
# checking the numerical approx. of J is close to the one provided autograd
|
||||
np.testing.assert_allclose(J_sum, NJ_sum, atol = 1e-5)
|
||||
|
||||
# test gradcheck
|
||||
gradcheck_test, _, _ = gradcheck(tiny_model, tiny_input)
|
||||
gradcheck_test, _, _ = gradcheck(tiny_model.forward, tiny_input)
|
||||
self.assertTrue(gradcheck_test)
|
||||
|
||||
# coarse approx. since a "big" eps and the non-linearities of the model
|
||||
gradcheck_test, j, nj = gradcheck(tiny_model, tiny_input, eps = 0.1)
|
||||
gradcheck_test, j, nj = gradcheck(tiny_model.forward, tiny_input, eps = 0.1)
|
||||
self.assertFalse(gradcheck_test)
|
||||
|
||||
def test_conv2d(self):
|
||||
|
||||
@@ -1,6 +1,33 @@
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
def jacobian(model, input):
|
||||
"""
|
||||
Compute the (analytical) Jacobian of model w.r.t. input.
|
||||
|
||||
model : A tinygrad model
|
||||
input : An input
|
||||
|
||||
returns:
|
||||
|
||||
J : Jacobian
|
||||
"""
|
||||
output = model(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
J = np.zeros((ji, jo))
|
||||
|
||||
for o in range(jo):
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.reshape(-1)):
|
||||
J[i][o] = grad
|
||||
return J
|
||||
|
||||
def mask_like(like, mask_inx, mask_value = 1.0):
|
||||
mask = np.zeros_like(like).reshape(-1)
|
||||
mask[mask_inx] = mask_value
|
||||
@@ -21,7 +48,7 @@ def numerical_jacobian(model, input, eps = 1e-6):
|
||||
|
||||
[1]: https://timvieira.github.io/blog/post/2017/04/21/how-to-test-gradient-implementations/
|
||||
"""
|
||||
output = model.forward(input)
|
||||
output = model(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
@@ -30,9 +57,9 @@ def numerical_jacobian(model, input, eps = 1e-6):
|
||||
for i in range(ji):
|
||||
for o in range(jo):
|
||||
|
||||
eps_pertub = mask_like(input.data, i, mask_value = eps)
|
||||
output_perturb_add = model.forward(Tensor(input.data + eps_pertub)).data.reshape(-1)[o]
|
||||
output_perturb_sub = model.forward(Tensor(input.data - eps_pertub)).data.reshape(-1)[o]
|
||||
eps_perturb = mask_like(input.data, i, mask_value = eps)
|
||||
output_perturb_add = model(Tensor(input.data + eps_perturb)).data.reshape(-1)[o]
|
||||
output_perturb_sub = model(Tensor(input.data - eps_perturb)).data.reshape(-1)[o]
|
||||
|
||||
grad_approx = ((output_perturb_add) - (output_perturb_sub)) / (2*eps)
|
||||
|
||||
@@ -42,7 +69,7 @@ def numerical_jacobian(model, input, eps = 1e-6):
|
||||
def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
|
||||
"""
|
||||
Checks whether the numerical approx. of the Jacobian of model w.r.t input is close to the
|
||||
analitical one (computed through .backward())
|
||||
analytical one.
|
||||
|
||||
model : A tinygrad model
|
||||
input : An input
|
||||
@@ -55,21 +82,7 @@ def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
|
||||
NJ : Finite-Difference approx. Jacobian
|
||||
"""
|
||||
NJ = numerical_jacobian(model, input, eps)
|
||||
|
||||
output = model.forward(input)
|
||||
|
||||
ji = input.data.reshape(-1).shape[-1]
|
||||
jo = output.data.reshape(-1).shape[-1]
|
||||
J = np.zeros((ji, jo))
|
||||
|
||||
for o in range(jo):
|
||||
# tinygrad doesn't support slicing, tiny-hack to select
|
||||
# the needed scalar an backpropagate only through it
|
||||
o_scalar = Tensor(mask_like(output.data, o, 1.)).mul(output).sum()
|
||||
o_scalar.backward()
|
||||
|
||||
for i, grad in enumerate(input.grad.reshape(-1)):
|
||||
J[i][o] = grad
|
||||
J = jacobian(model, input)
|
||||
|
||||
test_passed = np.allclose(J, NJ, atol=atol, rtol=rtol)
|
||||
return test_passed, J, NJ
|
||||
|
||||
Reference in New Issue
Block a user