gradient of COPY (#13760)

This commit is contained in:
chenyu
2025-12-19 13:33:59 -05:00
committed by GitHub
parent 57fe4d0a59
commit 185a000882
2 changed files with 7 additions and 0 deletions

View File

@@ -110,6 +110,12 @@ class TestTensorGradient(unittest.TestCase):
with self.assertRaises(RuntimeError): x.sum().gradient(x)
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
def test_copy_to_device_gradient(self):
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
t.to("CPU:1").square().sum().backward()
self.assertEqual(t.grad.device, t.device)
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
def test_multiple_backward(self):
x = Tensor([3.], requires_grad=True)
(x*2)[0].backward()