feat: allow passing gradient to .backward() to compute vjp (#5771)

* feat: allow passing gradient to .backward() to compute vjp

* fix

* refactor

* fix trailing whitespace
This commit is contained in:
David González Martínez
2024-07-28 20:13:18 +02:00
committed by GitHub
parent e0e7293b0a
commit d0fd84e617
2 changed files with 35 additions and 6 deletions

View File

@@ -16,6 +16,7 @@ U_init = np.random.randn(3,3).astype(np.float32)
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
gradient = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
def test_zerodim_initialization(self):
@@ -55,6 +56,31 @@ class TestTinygrad(unittest.TestCase):
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
# passing `gradient` to backward
def test_backward_pass_vjp(self):
def test_tinygrad():
x = Tensor(x_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
m = Tensor(m_init)
out = x.dot(W).relu()
out = out.log_softmax()
out = out.mul(m).add(m)
out.backward(Tensor(gradient))
return out.numpy(), x.grad.numpy(), W.grad.numpy()
def test_pytorch():
x = torch.tensor(x_init, requires_grad=True)
W = torch.tensor(W_init, requires_grad=True)
m = torch.tensor(m_init)
out = x.matmul(W).relu()
out = torch.nn.functional.log_softmax(out, dim=1)
out = out.mul(m).add(m)
out.backward(torch.tensor(gradient))
return out.detach().numpy(), x.grad, W.grad
for x,y in zip(test_tinygrad(), test_pytorch()):
np.testing.assert_allclose(x, y, atol=1e-5)
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461
def test_backward_pass_diamond_model(self):
def test_tinygrad():