mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 23:48:01 -05:00)
gradcheck now returns only a bool, refactoring of test_gradcheck
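In short: gradcheck used to return a (test_passed, J, NJ) tuple that callers had to unpack; after this commit it returns only the bool. A minimal sketch of the calling convention before and after, reusing names from the diff below:

    # before: unpack the tuple and ignore the Jacobians
    test_passed, _, _ = gradcheck(tiny_func, tiny_x)
    assert test_passed
    # after: use the returned bool directly
    assert gradcheck(tiny_func, tiny_x)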
@@ -52,36 +52,17 @@ class TestTinygrad(unittest.TestCase):
     np.testing.assert_allclose(PJ, NJ, atol = 1e-5)
 
   def test_gradcheck(self):
-    class TinyModel:
-      def __init__(self, weights_init):
-        self.l1 = Tensor(weights_init)
-      def forward(self, x):
-        return x.dot(self.l1).relu().logsoftmax()
+    W = np.random.RandomState(1337).random((10, 5))
+    x = np.random.RandomState(7331).random((1, 10)) - 0.5
 
-    class TorchModel(torch.nn.Module):
-      def __init__(self, weights_init):
-        super(TorchModel, self).__init__()
-        self.l1 = torch.nn.Linear(*weights_init.shape, bias = False)
-        self.l1.weight = torch.nn.Parameter(torch.tensor(weights_init.T, requires_grad = True))
-      def forward(self, x):
-        return torch.nn.functional.log_softmax(self.l1(x).relu(), dim=1)
+    tiny_x = Tensor(x)
+    tiny_W = Tensor(W)
+    tiny_func = lambda x: x.dot(tiny_W).relu().logsoftmax()
 
-    layer_weights = np.random.RandomState(1337).random((10, 5))
-    input_data = np.random.RandomState(7331).random((1, 10)) - 0.5
-
-    torch_input = torch.tensor(input_data, requires_grad = True)
-    torch_model = TorchModel(layer_weights)
-
-    tiny_model = TinyModel(layer_weights)
-    tiny_input = Tensor(input_data)
-
-    # test gradcheck
-    gradcheck_test, _, _ = gradcheck(tiny_model.forward, tiny_input)
-    self.assertTrue(gradcheck_test)
+    self.assertTrue(gradcheck(tiny_func, tiny_x))
 
     # coarse approx. since a "big" eps and the non-linearities of the model
-    gradcheck_test, j, nj = gradcheck(tiny_model.forward, tiny_input, eps = 0.1)
-    self.assertFalse(gradcheck_test)
+    self.assertFalse(gradcheck(tiny_func, tiny_x, eps = 0.1))
 
   def test_conv2d(self):
     x = torch.randn((5,2,10,7), requires_grad=True)
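A note on the eps = 0.1 assertion above: a finite-difference Jacobian is only an approximation, and with a large step the estimated slope of a non-linear function (here relu followed by logsoftmax) drifts well outside atol = 1e-5 / rtol = 0.001, so gradcheck is expected to return False. A standalone NumPy sketch of the effect near the relu kink (illustration only, not tinygrad's numerical_jacobian, which may use a different difference scheme):

    import numpy as np

    def relu(x): return np.maximum(x, 0.0)

    def central_diff(f, x, eps):
      # finite-difference estimate of df/dx at x
      return (f(x + eps) - f(x - eps)) / (2 * eps)

    x = 0.03                                 # just above the relu kink at 0
    print(central_diff(relu, x, eps=1e-6))   # ~1.0, matches the analytical gradient
    print(central_diff(relu, x, eps=0.1))    # ~0.65, far from the true slope of 1.0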
@@ -78,11 +78,8 @@ def gradcheck(model, input, eps = 1e-06, atol = 1e-5, rtol = 0.001):
 
   returns:
     test_passed : Bool, whether the test passed
-    J : Analytical Jacobian
-    NJ : Finite-Difference approx. Jacobian
   """
   NJ = numerical_jacobian(model, input, eps)
   J = jacobian(model, input)
 
-  test_passed = np.allclose(J, NJ, atol=atol, rtol=rtol)
-  return test_passed, J, NJ
+  return np.allclose(J, NJ, atol=atol, rtol=rtol)
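For reference, the check gradcheck performs boils down to comparing an analytical Jacobian against a finite-difference one with np.allclose. A self-contained sketch of that idea in plain NumPy (not tinygrad's jacobian / numerical_jacobian implementations):

    import numpy as np

    def fd_jacobian(f, x, eps=1e-6):
      # perturb one input coordinate at a time, record the change in every output
      y = f(x)
      J = np.zeros((y.size, x.size))
      for i in range(x.size):
        dx = np.zeros_like(x)
        dx.flat[i] = eps
        J[:, i] = (f(x + dx) - f(x - dx)).ravel() / (2 * eps)
      return J

    # for a linear map f(x) = W @ x the analytical Jacobian is just W,
    # so the comparison below should print True
    W = np.random.RandomState(1337).random((5, 10))
    f = lambda x: W @ x
    x = np.random.RandomState(7331).random(10) - 0.5
    print(np.allclose(W, fd_jacobian(f, x), atol=1e-5, rtol=0.001))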