mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
@@ -642,11 +642,11 @@ class TestTrainMode(unittest.TestCase):
|
||||
assert not Tensor.training
|
||||
|
||||
class TestInferenceMode(unittest.TestCase):
|
||||
def test_inference_mode(self):
|
||||
def test_inference(self):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
with Tensor.inference_mode():
|
||||
with Tensor.test():
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
out = mm.relu()
|
||||
@@ -663,7 +663,7 @@ class TestInferenceMode(unittest.TestCase):
|
||||
x = Tensor(x_init, requires_grad=True)
|
||||
m = Tensor(m_init, requires_grad=True)
|
||||
W = Tensor(W_init, requires_grad=True)
|
||||
@Tensor.inference_mode()
|
||||
@Tensor.test()
|
||||
def f(x, m, W):
|
||||
tmp = x.mul(m)
|
||||
mm = tmp.matmul(W)
|
||||
|
||||
Reference in New Issue
Block a user