mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
gradient of COPY (#13760)
This commit is contained in:
@@ -110,6 +110,12 @@ class TestTensorGradient(unittest.TestCase):
|
||||
with self.assertRaises(RuntimeError): x.sum().gradient(x)
|
||||
with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
|
||||
|
||||
def test_copy_to_device_gradient(self):
|
||||
t = Tensor([1.0, 2, 3], requires_grad=True).realize()
|
||||
t.to("CPU:1").square().sum().backward()
|
||||
self.assertEqual(t.grad.device, t.device)
|
||||
self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_multiple_backward(self):
|
||||
x = Tensor([3.], requires_grad=True)
|
||||
(x*2)[0].backward()
|
||||
|
||||
Reference in New Issue
Block a user