mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
add inference_mode context manager with decorator support (#3621)
* add inference_mode context manager with decorator support * change val to mode for train and inference_mode * fix wrong rename
This commit is contained in:
@@ -495,5 +495,50 @@ class TestTensorCreationDevice(unittest.TestCase):
|
||||
x = y.one_hot(10)
|
||||
x.realize()
|
||||
|
||||
class TestTrainMode(unittest.TestCase):
|
||||
def test_train_mode(self):
|
||||
assert not Tensor.training
|
||||
@Tensor.train()
|
||||
def f():
|
||||
assert Tensor.training
|
||||
f()
|
||||
assert not Tensor.training
|
||||
|
||||
class TestInferenceMode(unittest.TestCase):
|
||||
def test_inference_mode(self):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
with Tensor.inference_mode():
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
out = out.sum()
|
||||
out.backward()
|
||||
assert x.grad is None
|
||||
assert m.grad is None
|
||||
assert tmp.grad is None
|
||||
assert mm.grad is None
|
||||
assert W.grad is None
|
||||
assert W.requires_grad
|
||||
|
||||
def test_no_grad_mode_context_manager(self):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
@Tensor.inference_mode()
|
||||
def f(x, m, W):
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
out = out.sum()
|
||||
out.backward()
|
||||
assert x.grad is None
|
||||
assert m.grad is None
|
||||
assert tmp.grad is None
|
||||
assert mm.grad is None
|
||||
assert W.grad is None
|
||||
f(x, m, W)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user