Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
remove retain_graph in Tensor.backward [pr] (#8835)
The flag was not used; gradient accumulation works directly without it.
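To illustrate the claim that gradient accumulation works without the flag, here is a minimal sketch (my example, not code from this PR): calling backward() a second time simply adds another copy of the gradient into .grad, so there is nothing for a retain_graph flag to do.

```python
# Minimal sketch (not from this PR): repeated backward() calls accumulate into .grad,
# which is why the retain_graph flag has nothing left to do.
import numpy as np
from tinygrad import Tensor

x = Tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
out = (x * 2).sum()              # d(out)/dx == 2 everywhere

out.backward()
g1 = x.grad.numpy().copy()       # gradient after the first pass

out.backward()                   # second pass, no retain_graph needed
g2 = x.grad.numpy()

np.testing.assert_allclose(g2, 2 * g1)  # the second pass added the same gradient again
```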
@@ -328,7 +328,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
    torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -354,7 +354,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -380,7 +380,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x.permute(0,2,3,1)).permute(0,3,1,2)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=5e-4, rtol=5e-4)
@@ -406,7 +406,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -432,7 +432,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -467,7 +467,7 @@ class TestNN(unittest.TestCase):
 
     torch_x = torch.tensor(x.numpy(), requires_grad=True)
     torch_z = torch_layer(torch_x)
-    torch_z.sum().backward(retain_graph=True)
+    torch_z.sum().backward()
 
     np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-6, rtol=5e-6)
     np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
@@ -63,17 +63,16 @@ class TestTinygrad(unittest.TestCase):
     np.testing.assert_allclose(x, y, atol=1e-5)
 
   # A simple test is to check that we can accumulate gradients (run backward twice or more times)
-  # This will only work if retain_graph works.
-  def test_retain_graph(self):
+  def test_accumulate_gradients(self):
     x = Tensor(x_init, requires_grad=True)
     W = Tensor(W_init, requires_grad=True)
     m = Tensor(m_init)
     out = x.dot(W).relu()
     out = out.log_softmax()
     out = out.mul(m).add(m).sum()
-    out.backward(retain_graph=True)
+    out.backward()
     xgrad,wgrad = x.grad, W.grad
-    out.backward(retain_graph=True)
+    out.backward()
     xgrad2,wgrad2 = x.grad, W.grad
     out.backward() # no need to retain again since we will not re-run backward
     xgrad3,wgrad3 = x.grad, W.grad
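Since backward() always accumulates, the practical counterpart to this test is clearing gradients between optimizer steps rather than retaining or freeing a graph. A hedged sketch of such a loop, using tinygrad's stock SGD optimizer and its zero_grad()/step() calls (my example, not part of this diff), is below.

```python
# Sketch only (not part of this diff): accumulate-by-default means a training loop
# resets gradients explicitly with zero_grad() between steps.
from tinygrad import Tensor
from tinygrad.nn.optim import SGD

Tensor.training = True                 # tinygrad optimizers expect training mode
w = Tensor.randn(4, 4, requires_grad=True)
opt = SGD([w], lr=0.01)

for _ in range(3):
  opt.zero_grad()                      # drop gradients left over from the previous step
  loss = (Tensor.randn(2, 4) @ w).relu().sum()
  loss.backward()                      # no retain_graph argument anymore
  opt.step()
```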
@@ -915,11 +915,10 @@ class Tensor(SimpleMathTrait):
     # create returned Tensors
     return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
 
-  def backward(self, gradient:Optional[Tensor]=None, retain_graph:bool=False) -> Tensor:
+  def backward(self, gradient:Optional[Tensor]=None) -> Tensor:
     """
     Propagates the gradient of a tensor backwards through the computation graph.
     If the 'gradient' argument is not provided, the tensor must be a scalar, and the gradient is implicitly set to 1.0.
-    If 'retain_graph' is false, the graph used to compute the grads will be freed. Otherwise, it will be kept. Keeping it can increase memory usage.
     ```python exec="true" source="above" session="tensor" result="python"
     t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
     t.sum().backward()
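For the surviving parameter, a short usage sketch (mine, not from the docstring): a non-scalar tensor needs an explicit gradient argument of the same shape, and retain_graph is no longer accepted as a keyword.

```python
# Sketch of the post-PR signature: backward(gradient: Optional[Tensor] = None).
from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
y = t * 3                                   # non-scalar output

y.backward(Tensor([1.0, 1.0, 1.0, 1.0]))    # explicit upstream gradient, same shape as y
print(t.grad.numpy())                       # -> [3. 3. 3. 3.]

# y.backward(retain_graph=True) would now raise a TypeError, since the keyword is gone
```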