mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
add test for checking const gradients (#9598)
@@ -152,6 +152,27 @@ class TestTinygrad(unittest.TestCase):
     for x,y in zip(test_tinygrad(), test_pytorch()):
       np.testing.assert_allclose(x, y, atol=1e-5, rtol=1e-6)
 
+  @unittest.expectedFailure
+  def test_const_backward_pass(self):
+    init = 3.5
+
+    def test_pytorch():
+      w1 = torch.tensor(init, requires_grad=True)
+      w2 = torch.tensor(init, requires_grad=True)
+      out = w1.add(w2)
+      out.backward()
+      return w1.grad, w2.grad
+
+    def test_tinygrad():
+      w1 = Tensor(init, requires_grad=True)
+      w2 = Tensor(init, requires_grad=True)
+      out = w1.add(w2)
+      out.backward()
+      return w1.grad.numpy(), w2.grad.numpy()
+
+    for x, y in zip(test_tinygrad(), test_pytorch()):
+      np.testing.assert_allclose(x, y, atol=1e-5)
+
   def test_nograd(self):
     x = Tensor(x_init, requires_grad=False)
     m = Tensor(m_init, requires_grad=False)
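
For context, a minimal standalone sketch of the reference behavior the new test encodes, using PyTorch as in the test above: for out = w1 + w2, the partial derivative of out with respect to each addend is 1, so both scalar tensors should carry a gradient of exactly 1.0 after backward(). Note the test is marked @unittest.expectedFailure, so the snippet below documents the expected reference behavior, not tinygrad's behavior at this commit.

import torch

# Expected reference behavior: d(w1 + w2)/dw1 = d(w1 + w2)/dw2 = 1,
# so each 0-dim tensor should receive a gradient of exactly 1.0.
w1 = torch.tensor(3.5, requires_grad=True)
w2 = torch.tensor(3.5, requires_grad=True)
out = w1.add(w2)
out.backward()
assert w1.grad.item() == 1.0 and w2.grad.item() == 1.0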