Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 23:18:04 -05:00
remove retain_graph in Tensor.backward [pr] (#8835)
retain_graph is not used; gradient accumulation works directly without it.
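For context, a minimal sketch (not part of this commit) of the accumulation behavior the message refers to: calling backward() more than once simply adds into .grad, so there is no graph that needs retaining. The shapes and the 2x check below are illustrative assumptions, not taken from the test suite.

import numpy as np
from tinygrad import Tensor

x = Tensor(np.random.randn(3, 3).astype(np.float32), requires_grad=True)
out = (x * 2).sum()

out.backward()
g1 = x.grad.numpy().copy()      # gradient after the first backward pass

out.backward()                  # a second pass accumulates into x.grad
np.testing.assert_allclose(x.grad.numpy(), 2 * g1, rtol=1e-5)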
@@ -328,7 +328,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -354,7 +354,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -380,7 +380,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -406,7 +406,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -432,7 +432,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -467,7 +467,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
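The torch reference code in the TestNN hunks above only runs backward() once per test, so retain_graph=True was never needed there either; in PyTorch the flag only matters when the same graph is backpropagated a second time. A hedged illustration of that PyTorch behavior (not part of the diff):

import torch

x = torch.ones(3, requires_grad=True)

z = (x * 2).sum()
z.backward()                   # single pass: no retain_graph required

z2 = (x * 2).sum()
z2.backward(retain_graph=True) # keep the graph alive...
z2.backward()                  # ...only because we backprop it a second time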
@@ -63,17 +63,16 @@ class TestTinygrad(unittest.TestCase):
     np.testing.assert_allclose(x, y, atol=1e-5)
 
   # A simple test is to check that we can accumulate gradients (run backward twice or more times)
-  # This will only work if retain_graph works.
-  def test_retain_graph(self):
+  def test_accumulate_gradients(self):
     x = Tensor(x_init, requires_grad=True)
     W = Tensor(W_init, requires_grad=True)
     m = Tensor(m_init)
     out = x.dot(W).relu()
     out = out.log_softmax()
     out = out.mul(m).add(m).sum()
-    out.backward(retain_graph=True)
+    out.backward()
     xgrad,wgrad = x.grad, W.grad
-    out.backward(retain_graph=True)
+    out.backward()
     xgrad2,wgrad2 = x.grad, W.grad
     out.backward() # no need to retain again since we will not re-run backward
     xgrad3,wgrad3 = x.grad, W.grad
@@ -915,11 +915,10 @@ class Tensor(SimpleMathTrait):
     # create returned Tensors
     return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
 
-  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
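Per the docstring kept in the Tensor.backward hunk above, a non-scalar tensor still needs an explicit gradient argument under the new signature. A small usage sketch (assumed behavior, not taken from the diff):

from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
t.backward(Tensor([1.0, 1.0, 1.0, 1.0]))  # explicit gradient for a non-scalar tensor
print(t.grad.numpy())                     # expected: [1. 1. 1. 1.]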