remove retain_graph in Tensor.backward [pr] (#8835)

not used. gradient accumulation works directly
Author: chenyu
Date: 2025-01-31 13:41:26 -05:00
Committed by: GitHub
Parent: 0a59db936a
Commit: 1f730ae8f8
3 changed files with 10 additions and 12 deletions
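
As context for the commit message's claim that gradient accumulation works without a retain_graph flag, here is a minimal sketch (not part of this diff) showing that repeated backward() calls simply add into .grad; it assumes a tinygrad build that includes this change:

```python
import numpy as np
from tinygrad import Tensor

# illustrative check, not part of the commit: backward() can be re-run on the
# same graph and gradients accumulate into .grad, so no retain_graph is needed
x = Tensor(np.arange(4, dtype=np.float32), requires_grad=True)
out = (x * 2).sum()

out.backward()
g1 = x.grad.numpy().copy()  # first pass: d(out)/dx = 2 for every element

out.backward()
g2 = x.grad.numpy()         # second pass adds another 2, giving 4

np.testing.assert_allclose(g2, 2 * g1)
```

The renamed test_accumulate_gradients below exercises the same behavior on a larger graph.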


@@ -328,7 +328,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -354,7 +354,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -380,7 +380,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -406,7 +406,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -432,7 +432,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -467,7 +467,7 @@ class TestNN(unittest.TestCase):
 torch_x = torch.tensor(x.numpy(), requires_grad=True)
 torch_z = torch_layer(torch_x)
-torch_z.sum().backward(retain_graph=True)
+torch_z.sum().backward()
 np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
 np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)


@@ -63,17 +63,16 @@ class TestTinygrad(unittest.TestCase):
 np.testing.assert_allclose(x, y, atol=1e-5)
 # A simple test is to check that we can accumulate gradients (run backward twice or more times)
-# This will only work if retain_graph works.
-def test_retain_graph(self):
+def test_accumulate_gradients(self):
 x = Tensor(x_init, requires_grad=True)
 W = Tensor(W_init, requires_grad=True)
 m = Tensor(m_init)
 out = x.dot(W).relu()
 out = out.log_softmax()
 out = out.mul(m).add(m).sum()
-out.backward(retain_graph=True)
+out.backward()
 xgrad,wgrad = x.grad, W.grad
-out.backward(retain_graph=True)
+out.backward()
 xgrad2,wgrad2 = x.grad, W.grad
 out.backward() # no need to retain again since we will not re-run backward
 xgrad3,wgrad3 = x.grad, W.grad


@@ -915,11 +915,10 @@ class Tensor(SimpleMathTrait):
 # create returned Tensors
 return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
-def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
+def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
 """
 Propagates the gradient of a tensor backwards through the computation graph.
 If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
 ```python exec="true" source="above" session="tensor" result="python"
 t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
 t.sum().backward()
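
With retain_graph gone, gradient is the only optional argument left on backward. As a rough sketch (not part of this diff), this is what the non-scalar path looks like when an explicit upstream gradient is passed:

```python
from tinygrad import Tensor

# sketch of the remaining optional argument: for a non-scalar output,
# an explicit upstream gradient tensor is supplied to backward()
t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y = t * t                                  # non-scalar output
y.backward(Tensor([1.0, 1.0, 1.0, 1.0]))   # upstream gradient of ones
print(t.grad.numpy())                      # ~ [2. 4. 6. 8.], i.e. 2*t
```

Omitting the argument, as in the docstring example above, is only valid when the tensor is a scalar.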