assert to prepare for grad uop [pr] (#8280)

* assert to prepare for grad uop [pr]

* fix test_nn

* fix most of test_tensor

* few more tests

* fix multi

* uniform gradient

* acc_dtype

* any for multi

* fix typing

* fix assert, CAST_BEFORE_VIEW is still the issue

* explicit test for CAST_BEFORE_VIEW

---------

Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
George Hotz
2025-01-14 13:26:56 -08:00
committed by GitHub
parent fdd46c9f28
commit c85737c200
5 changed files with 20 additions and 2 deletions

View File

@@ -93,6 +93,12 @@ class TestTensorGradient(unittest.TestCase):
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
def test_cast_before_view(self):
    """Gradient flows through a cast applied after a reshape.

    Regression-style smoke test for the CAST_BEFORE_VIEW path: we take the
    gradient with respect to the reshaped (pre-cast) tensor and just check
    that the call completes.
    """
    base = Tensor([1.0, 1, 1, 1])
    reshaped = base.reshape(2, 2)
    # chain cast -> mean, then differentiate w.r.t. the pre-cast tensor
    reshaped.cast(dtypes.float16).mean().gradient(reshaped)
class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
@@ -104,5 +110,11 @@ class TestRealizeMeansRealize(unittest.TestCase):
print(x.lazydata)
self.assertEqual(x.lazydata.op, Ops.VIEW)
# NOTE: even though it doesn't realize, this seems fine
def test_uniform_gradient(self):
    """Smoke test: gradient of a simple elementwise op on a realized uniform tensor."""
    t = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
    doubled = t * 2
    # realize the gradient w.r.t. t to force the backward graph to execute
    doubled.sum().gradient(t)[0].realize()
# Allow running this test file directly (e.g. `python test_file.py`).
if __name__ == '__main__':
    unittest.main()