mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
assert to prepare for grad uop [pr] (#8280)
* assert to prepare for grad uop [pr] * fix test_nn * fix most of test_tensor * few more tests * fix multi * uniform gradient * acc_dtype * any for multi * fix typing * fix assert, CAST_BEFORE_VIEW is still the issue * explicit test for CAST_BEFORE_VIEW --------- Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
@@ -93,6 +93,12 @@ class TestTensorGradient(unittest.TestCase):
|
||||
dx = z.gradient(x, gradient=dz)[0]
|
||||
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
|
||||
|
||||
def test_cast_before_view(self):
  """Gradient through cast-after-reshape must not raise (CAST_BEFORE_VIEW case)."""
  base = Tensor([1.0, 1, 1, 1])
  # view first, then change dtype — the ordering this regression test targets
  reshaped = base.reshape(2, 2)
  halved = reshaped.cast(dtypes.float16)
  # only checks that gradient computation completes; no value assertion here
  halved.mean().gradient(reshaped)
||||
class TestRealizeMeansRealize(unittest.TestCase):
|
||||
def test_randn_realizes(self):
|
||||
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
|
||||
@@ -104,5 +110,11 @@ class TestRealizeMeansRealize(unittest.TestCase):
|
||||
print(x.lazydata)
|
||||
self.assertEqual(x.lazydata.op, Ops.VIEW)
|
||||
|
||||
# NOTE: even though it doesn't realize, this seems fine
def test_uniform_gradient(self):
  """Gradient of a sum over a scaled uniform tensor realizes without error."""
  src = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
  doubled = src * 2
  # realize the gradient w.r.t. the uniform source; no value assertion here
  doubled.sum().gradient(src)[0].realize()
||||
# Script entry point: run the unittest discovery/runner for this module.
if __name__ == '__main__':
  unittest.main()
Reference in New Issue
Block a user