mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
Test for optim assertion (#4558)
* add test for assertion * whitespace * restore state --------- Co-authored-by: Thomas Ziereis <thomas.ziereis@web.de>
This commit is contained in:
@@ -119,5 +119,17 @@ class TestOptim(unittest.TestCase):
|
||||
self._test_adamw(1, {'lr': 1e10}, 1e-4, 1e-4)
|
||||
dtypes.default_float = old_default_float
|
||||
|
||||
def test_assert_tensor_train(self):
|
||||
t = Tensor.ones((1,1), requires_grad=True)
|
||||
optimizer = Adam([t])
|
||||
optimizer.zero_grad()
|
||||
old_state = Tensor.training
|
||||
t.sum().backward()
|
||||
Tensor.training = False
|
||||
self.assertRaises(AssertionError, optimizer.step)
|
||||
Tensor.training = True
|
||||
optimizer.step()
|
||||
Tensor.training = old_state
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user