gradient of COPY (#13760)
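Adds a gradient rule for Ops.COPY: the incoming gradient is copied back to the source tensor's device, and the device operand itself receives no gradient (None). The new test below backpropagates through t.to("CPU:1") and checks that t.grad lands back on t's original device with the expected values.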
@@ -110,6 +110,12 @@ class TestTensorGradient(unittest.TestCase):
     with self.assertRaises(RuntimeError): x.sum().gradient(x)
     with self.assertRaises(RuntimeError): x.float().sum().gradient(x)
 
+  def test_copy_to_device_gradient(self):
+    t = Tensor([1.0, 2, 3], requires_grad=True).realize()
+    t.to("CPU:1").square().sum().backward()
+    self.assertEqual(t.grad.device, t.device)
+    self.assertListEqual(t.grad.tolist(), [2.0, 4.0, 6.0])
+
   def test_multiple_backward(self):
     x = Tensor([3.], requires_grad=True)
     (x*2)[0].backward()
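For reference, the expected values follow from d/dt sum(t^2) = 2t, so t = [1.0, 2, 3] gives [2.0, 4.0, 6.0]. A minimal sketch of the same check without the device hop (assumes a standard tinygrad install; uses only the public Tensor API already exercised by the test above):

    from tinygrad import Tensor

    t = Tensor([1.0, 2, 3], requires_grad=True)
    t.square().sum().backward()   # d/dt sum(t^2) = 2t
    print(t.grad.tolist())        # [2.0, 4.0, 6.0]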
@@ -39,6 +39,7 @@ pm_gradient = PatternMatcher([
   (UPat(Ops.SHRINK, name="ret"), lambda ctx, ret: (ctx.pad(tuple([(p[0], s-p[1]) for s,p in zip(ret.src[0].shape, ret.marg)])), None, None)),
   (UPat(Ops.PERMUTE, name="ret"), lambda ctx, ret: (ctx.permute(argsort(ret.marg)),)),
   (UPat(Ops.FLIP, name="ret"), lambda ctx, ret: (ctx.flip([i for i,x in enumerate(ret.marg) if x]),)),
+  (UPat(Ops.COPY, name="ret"), lambda ctx, ret: (ctx.copy_to_device(ret.src[0].device), None)),
   (UPat(Ops.MULTI, name="ret"), lambda ctx, ret: ctx.shard(ret.device, ret.axis).src),
   # NOTE: this is only correct when the KERNEL has a single output
   (UPat(Ops.AFTER), lambda ctx: (ctx, ctx)),
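Each rule in pm_gradient returns one entry per source of the matched op, with None marking sources that get no gradient; COPY's sources are (data, device), so its rule copies the incoming gradient back to the data's device and skips the device operand. A self-contained toy sketch of that reverse-mode behavior (plain Python; Node, copy_to_device, square, and backward are illustrative stand-ins, not tinygrad's internals):

    # Toy reverse-mode sketch: the gradient of a device copy is a copy back.
    class Node:
      def __init__(self, data, device="CPU:0", parents=(), bwd=None):
        self.data, self.device = data, device
        self.parents, self.bwd = parents, bwd
        self.grad = None

    def copy_to_device(x, device):
      out = Node(x.data, device, parents=(x,))
      # backward: send the incoming gradient back to the source device,
      # mirroring ctx.copy_to_device(ret.src[0].device) in the rule above
      out.bwd = lambda g: [Node(g.data, x.device)]
      return out

    def square(x):
      out = Node([v*v for v in x.data], x.device, parents=(x,))
      out.bwd = lambda g: [Node([2*v*gi for v, gi in zip(x.data, g.data)], x.device)]
      return out

    def backward(out):
      # seeding with ones is equivalent to backpropagating through sum()
      out.grad = Node([1.0]*len(out.data), out.device)
      stack = [out]
      while stack:
        node = stack.pop()
        if node.bwd is None: continue
        for parent, g in zip(node.parents, node.bwd(node.grad)):
          parent.grad = g
          stack.append(parent)

    t = Node([1.0, 2.0, 3.0])
    backward(square(copy_to_device(t, "CPU:1")))
    print(t.grad.data, t.grad.device)  # [2.0, 4.0, 6.0] CPU:0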