Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00.
feat: allow passing gradient to .backward() to compute vjp (#5771)

* feat: allow passing gradient to .backward() to compute vjp
* fix
* refactor
* fix trailing whitespace

This commit is contained in this branch; committed by GitHub.
Parent: e0e7293b0a
Commit: d0fd84e617
@@ -16,6 +16,7 @@ U_init = np.random.randn(3,3).astype(np.float32)
|
||||
# Shared random test fixtures, cast to float32 to match tinygrad's default dtype.
# NOTE: creation order matters — these draw from the global NumPy RNG stream.
V_init = np.random.randn(3,3).astype(np.float32)
W_init = np.random.randn(3,3).astype(np.float32)
m_init = np.random.randn(1,3).astype(np.float32)
# Upstream cotangent passed to .backward() by the vjp test.
gradient = np.random.randn(1,3).astype(np.float32)
class TestTinygrad(unittest.TestCase):
|
||||
def test_zerodim_initialization(self):
|
||||
@@ -55,6 +56,31 @@ class TestTinygrad(unittest.TestCase):
|
||||
for x,y in zip(test_tinygrad(), test_pytorch()):
|
||||
np.testing.assert_allclose(x, y, atol=1e-5)
|
||||
|
||||
# Passing a `gradient` tensor to .backward() computes a vector-Jacobian
# product (vjp) rather than assuming an implicit scalar loss; the result is
# checked against PyTorch's equivalent behavior.
def test_backward_pass_vjp(self):
  def run_tinygrad():
    x = Tensor(x_init, requires_grad=True)
    W = Tensor(W_init, requires_grad=True)
    m = Tensor(m_init)
    out = x.dot(W).relu().log_softmax().mul(m).add(m)
    # Seed backward with an explicit cotangent instead of a scalar loss.
    out.backward(Tensor(gradient))
    return out.numpy(), x.grad.numpy(), W.grad.numpy()

  def run_pytorch():
    x = torch.tensor(x_init, requires_grad=True)
    W = torch.tensor(W_init, requires_grad=True)
    m = torch.tensor(m_init)
    out = torch.nn.functional.log_softmax(x.matmul(W).relu(), dim=1).mul(m).add(m)
    out.backward(torch.tensor(gradient))
    return out.detach().numpy(), x.grad, W.grad

  for actual, expected in zip(run_tinygrad(), run_pytorch()):
    np.testing.assert_allclose(actual, expected, atol=1e-5)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "this test uses more than 8 bufs which breaks webgpu") #TODO: remove after #1461
|
||||
def test_backward_pass_diamond_model(self):
|
||||
def test_tinygrad():
|
||||
|
||||
Reference in New Issue
Block a user