Test for optim assertion (#4558)

* add test for assertion

* whitespace

* restore state

---------

Co-authored-by: Thomas Ziereis <thomas.ziereis@web.de>
This commit is contained in:
ziereis
2024-05-12 23:21:28 +02:00
committed by GitHub
parent d7670f8141
commit f53a23d21e

View File

@@ -119,5 +119,17 @@ class TestOptim(unittest.TestCase):
self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
dtypes.default_float = old_default_float
def test_assert_tensor_train(self):
t = Tensor.ones((1,1), requires_grad=True)
optimizer = Adam([t])
optimizer.zero_grad()
old_state = Tensor.training
t.sum().backward()
Tensor.training = False
self.assertRaises(AssertionError, optimizer.step)
Tensor.training = True
optimizer.step()
Tensor.training = old_state
if __name__ == '__main__':
unittest.main()
unittest.main()