gradient of COPY (#13760)

Author: chenyu
Committed by: GitHub
Date: 2025-12-19 13:33:59 -05:00
Parent: 57fe4d0a59
Commit: 185a000882
2 changed files with 7 additions and 0 deletions


@@ -110,6 +110,12 @@ class TestTensorGradient(unittest.TestCase):
     with self.assertRaises(RuntimeError): x.sum().gradient(x)
     with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
 
+  def test_copy_to_device_gradient(self):
+    t = Tensor([1.0, 2, 3], requires_grad=True).realize()
+    t.to("CPU:1").square().sum().backward()
+    self.assertEqual(t.grad.device, t.device)
+    self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
+
   def test_multiple_backward(self):
     x = Tensor([3.], requires_grad=True)
     (x*2)[0].backward()
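
The expected values follow from the chain rule: for loss = sum(t_i^2), dloss/dt_i = 2*t_i, which is [2.0, 4.0, 6.0] for t = [1.0, 2, 3]. Carrying that gradient back across the device copy to t.device is what the new COPY rule in the second hunk provides.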


@@ -39,6 +39,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
+  (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
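
A minimal usage sketch of the new rule, mirroring the test above (the second device name "CPU:1" is only an illustrative assumption; any other available device works the same way):

from tinygrad import Tensor

t = Tensor([1.0, 2, 3], requires_grad=True).realize()
# forward: t.to(...) becomes an Ops.COPY across devices; square/sum run on the target device
loss = t.to("CPU:1").square().sum()
# backward: the COPY rule above returns ctx.copy_to_device(ret.src[0].device),
# so the incoming gradient is copied back to t's original device
loss.backward()
print(t.grad.device)    # same as t.device
print(t.grad.tolist())  # [2.0, 4.0, 6.0]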