tensor inference (#6156)

* tensor inference

* test is even better name
This commit is contained in:
George Hotz
2024-08-18 00:19:28 -07:00
committed by GitHub
parent 9db2d0d5c6
commit 17a043edad
3 changed files with 12 additions and 11 deletions

View File

@@ -642,11 +642,11 @@ class TestTrainMode(unittest.TestCase):
assert not Tensor.training
class TestInferenceMode(unittest.TestCase):
def test_inference_mode(self):
def test_inference(self):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
with Tensor.inference_mode():
with Tensor.test():
tmp = x.mul(m)
mm = tmp.matmul(W)
out = mm.relu()
@@ -663,7 +663,7 @@ class TestInferenceMode(unittest.TestCase):
x = Tensor(x_init, requires_grad=True)
m = Tensor(m_init, requires_grad=True)
W = Tensor(W_init, requires_grad=True)
@Tensor.inference_mode()
@Tensor.test()
def f(x, m, W):
tmp = x.mul(m)
mm = tmp.matmul(W)