add inference_mode context manager with decorator support (#3621)

* add inference_mode context manager with decorator support

* change val to mode for train and inference_mode

* fix wrong rename
This commit is contained in:
Maximilian Wolf
2024-03-09 17:38:26 +01:00
committed by GitHub
parent b5cbf1792a
commit 8ae85b2cf5
2 changed files with 53 additions and 3 deletions

View File

@@ -495,5 +495,50 @@ class TestTensorCreationDevice(unittest.TestCase):
x = y.one_hot(10)
x.realize()
class TestTrainMode(unittest.TestCase):
def test_train_mode(self):
assert not Tensor.training
@Tensor.train()
def f():
assert Tensor.training
f()
assert not Tensor.training
class TestInferenceMode(unittest.TestCase):
def test_inference_mode(self):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
with Tensor.inference_mode():
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
out = out.sum()
out.backward()
assert x.grad is None
assert m.grad is None
assert tmp.grad is None
assert mm.grad is None
assert W.grad is None
assert W.requires_grad
def test_no_grad_mode_context_manager(self):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
@Tensor.inference_mode()
def f(x, m, W):
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
out = out.sum()
out.backward()
assert x.grad is None
assert m.grad is None
assert tmp.grad is None
assert mm.grad is None
assert W.grad is None
f(x, m, W)
if __name__ == '__main__':
unittest.main()