From 185a00088212daeaada5d03f533389352b048e7d Mon Sep 17 00:00:00 2001
From: chenyu
Date: Fri, 19 Dec 2025 13:33:59 -0500
Subject: [PATCH] gradient of COPY (#13760)

---
 test/unit/test_gradient.py | 6 ++++++
 tinygrad/gradient.py       | 1 +
 2 files changed, 7 insertions(+)

diff --git a/test/unit/test_gradient.py b/test/unit/test_gradient.py
index 08abcf64ff..c91cdbf83c 100644
--- a/test/unit/test_gradient.py
+++ b/test/unit/test_gradient.py
@@ -110,6 +110,12 @@ class TestTensorGradient(unittest.TestCase):
     with self.assertRaises(RuntimeError): x.sum().gradient(x)
     with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
 
+  def test_copy_to_device_gradient(self):
+    t = Tensor([1.0, 2, 3], requires_grad=True).realize()
+    t.to("CPU:1").square().sum().backward()
+    self.assertEqual(t.grad.device, t.device)
+    self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
+
   def test_multiple_backward(self):
     x = Tensor([3.], requires_grad=True)
     (x*2)[0].backward()
diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py
index b4576cdf27..1447c69113 100644
--- a/tinygrad/gradient.py
+++ b/tinygrad/gradient.py
@@ -39,6 +39,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
+  (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
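
A minimal usage sketch of what this patch enables, assuming a tinygrad checkout with the change applied and a second device named "CPU:1" available (the same setup as test_copy_to_device_gradient above). A cross-device .to() lowers to an Ops.COPY, and the new pm_gradient rule copies the incoming gradient back to the source device (ret.src[0].device), so the gradient lands on the same device as the original tensor:

```python
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True).realize()

# Forward pass crosses a device boundary: .to() lowers to an Ops.COPY.
y = x.to("CPU:1")
loss = y.square().sum()

# Backward: the new COPY rule routes the gradient back to x's device,
# so x.grad has the same device as x rather than "CPU:1".
loss.backward()
print(x.grad.device)    # same device as x
print(x.grad.tolist())  # [2.0, 4.0, 6.0], since d/dx sum(x^2) = 2x
```

Without the rule, backpropagating through a COPY had no matching pattern in pm_gradient; with it, the gradient of a device transfer is simply the reverse transfer, which is why the second element of the returned tuple is None (the device argument itself has no gradient).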