Files
tinygrad/test/unit/test_gradient.py
chenyu bf33c5f796 remove gradient materialize_grads (#15367)
`materialize_grads` now effectively defaults to True.

Also removed the `*0` hack in Tensor.copysign: dy/dx is now 0 if y does not depend on x.

remove
2026-03-19 23:36:03 -04:00

91 lines
3.1 KiB
Python

import unittest
import numpy as np
from tinygrad import Tensor
from tinygrad.dtype import dtypes
class TestTensorGradient(unittest.TestCase):
  """Checks ``Tensor.gradient`` and ``Tensor.backward`` against hand-computed values."""

  def test_example(self):
    """Gradients of ``sum(row @ eye)`` with respect to both operands."""
    eye = Tensor.eye(3)
    row = Tensor([[2.0,0,-2.0]])
    loss = row.matmul(eye).sum()
    d_eye, d_row = loss.gradient(eye, row)
    self.assertListEqual(d_eye.tolist(), [[2.0, 2.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -2.0, -2.0]])
    self.assertListEqual(d_row.tolist(), [[1.0, 1.0, 1.0]])

  def test_zero_if_not_used(self):
    """The gradient w.r.t. a tensor the output never touches is all zeros."""
    used = Tensor([1.0, 2.0, 3.0])
    unused = Tensor.randn((3,))
    grads = used.sum().gradient(unused)
    self.assertListEqual(grads[0].tolist(), [0.0, 0.0, 0.0])

  def test_with_custom_gradient(self):
    """An explicit ``gradient=`` seed scales the result (here by 3)."""
    inp = Tensor([1.0, 2.0, 3.0])
    total = (inp * inp).sum()
    seed = Tensor([3.0])
    d_inp = total.gradient(inp, gradient=seed)[0]
    self.assertListEqual(d_inp.tolist(), [6.0, 12.0, 18.0])

  def test_broadcast_gradient(self):
    """Gradients of a broadcast add are reduced back to each operand's shape."""
    col = Tensor([[1.0], [2.0], [3.0]])
    row = Tensor([[10.0, 20.0, 30.0, 40.0]])
    total = (col + row).sum()
    d_col, d_row = total.gradient(col, row)
    self.assertListEqual(d_col.tolist(), [[4.0], [4.0], [4.0]])
    self.assertListEqual(d_row.tolist(), [[3.0, 3.0, 3.0, 3.0]])

  def test_non_scalar_output(self):
    """A non-scalar output needs an explicit ``gradient=``; without one an AssertionError is raised."""
    inp = Tensor([1.0, 2.0, 3.0])
    squared = inp * inp
    with self.assertRaises(AssertionError):
      squared.gradient(inp)
    ones = Tensor([1.0, 1.0, 1.0])
    d_inp = squared.gradient(inp, gradient=ones)[0]
    self.assertListEqual(d_inp.tolist(), [2.0, 4.0, 6.0])

  def test_cast_before_view(self):
    """Smoke test: gradient w.r.t. a reshaped tensor through a float16 cast must not raise."""
    flat = Tensor([1.0, 1, 1, 1])
    square = flat.reshape(2,2)
    half = square.cast(dtypes.float16)
    # No value is asserted; the call completing is the whole check.
    half.mean().gradient(square)

  def test_non_float_tensor_raise(self):
    """Asking for the gradient w.r.t. an integer tensor raises RuntimeError."""
    ints = Tensor([1, 2, 3])
    with self.assertRaises(RuntimeError):
      ints.sum().gradient(ints)
    # Still raises when the output is computed from a float cast of the integer tensor.
    with self.assertRaises(RuntimeError):
      ints.float().sum().gradient(ints)

  def test_copy_to_device_gradient(self):
    """After a ``.to("CPU:1")`` copy, the grad is stored on the source tensor's own device."""
    src = Tensor([1.0, 2, 3], requires_grad=True).realize()
    moved = src.to("CPU:1")
    moved.square().sum().backward()
    self.assertEqual(src.grad.device, src.device)
    self.assertListEqual(src.grad.tolist(), [2.0, 4.0, 6.0])

  def test_multiple_backward(self):
    """Repeated ``backward`` calls accumulate into the same ``grad`` object."""
    leaf = Tensor([3.], requires_grad=True)

    (leaf*2)[0].backward()
    np.testing.assert_allclose(leaf.grad.numpy(), [2.0])
    first_grad = leaf.grad

    # d(3x)/dx = 3 is added on top of the existing 2.
    (leaf*3)[0].backward()
    np.testing.assert_allclose(leaf.grad.numpy(), [2.0+3.0])
    self.assertIs(leaf.grad, first_grad)

    # d(x*x)/dx = 2*x = 2*3 at x = 3.
    (leaf*leaf)[0].backward()
    np.testing.assert_allclose(leaf.grad.numpy(), [2.0+3.0+2*3.0])
    self.assertIs(leaf.grad, first_grad)

  def test_gradient_through_chained_unrealized_setitem(self):
    """A seed gradient built from two chained setitems flows through ``pad`` to an all-zero result."""
    inner = Tensor.zeros(4).contiguous()
    inner[2] = Tensor(1.0)
    seed = Tensor.zeros(5, 4).contiguous()
    seed[0] = inner
    inp = Tensor.randn(4, 4)
    # The only non-zero seed entry is in row 0, which presumably corresponds to the
    # padded row rather than to any element of `inp` (pad is ((1,0),(0,0))).
    padded = inp.pad(((1,0),(0,0)))
    d_inp = padded.gradient(inp, gradient=seed)[0]
    np.testing.assert_allclose(d_inp.numpy(), np.zeros((4, 4)))
class TestViewGradient(unittest.TestCase):
  """Checks gradients recorded on view tensors after ``backward``."""

  def test_expand(self):
    """The expanded view's grad matches the flattened multiplier and differs from the base grad expanded."""
    weights = Tensor.randn(5,2)
    base = Tensor([3.], requires_grad=True)
    expanded = base.expand(10)
    product = expanded.reshape(5,2) * weights
    product.sum().backward()
    np.testing.assert_allclose(expanded.grad.numpy(), weights.reshape(10).numpy())
    # assert_allclose is expected to fail here, i.e. the two gradients are NOT equal.
    with self.assertRaises(AssertionError):
      np.testing.assert_allclose(expanded.grad.numpy(), base.grad.expand(10).numpy())
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
  unittest.main()