diff --git a/examples/beautiful_mnist.py b/examples/beautiful_mnist.py index 88f3d0d986..7f003f6f8c 100644 --- a/examples/beautiful_mnist.py +++ b/examples/beautiful_mnist.py @@ -24,16 +24,17 @@ if __name__ == "__main__": opt = nn.optim.Adam(nn.state.get_parameters(model)) @TinyJit + @Tensor.train() def train_step() -> Tensor: - with Tensor.train(): - opt.zero_grad() - samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) - # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed - loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward() - opt.step() - return loss + opt.zero_grad() + samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]) + # TODO: this "gather" of samples is very slow. will be under 5s when this is fixed + loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward() + opt.step() + return loss @TinyJit + @Tensor.test() def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100 test_acc = float('nan') diff --git a/test/test_tensor.py b/test/test_tensor.py index 4bfa697248..af27b0aa82 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -642,11 +642,11 @@ class TestTrainMode(unittest.TestCase): assert not Tensor.training class TestInferenceMode(unittest.TestCase): - def test_inference_mode(self): + def test_inference(self): x = Tensor(x_init, requires_grad=True) m = Tensor(m_init, requires_grad=True) W = Tensor(W_init, requires_grad=True) - with Tensor.inference_mode(): + with Tensor.test(): tmp = x.mul(m) mm = tmp.matmul(W) out = mm.relu() @@ -663,7 +663,7 @@ class TestInferenceMode(unittest.TestCase): x = Tensor(x_init, requires_grad=True) m = Tensor(m_init, requires_grad=True) W = Tensor(W_init, requires_grad=True) - @Tensor.inference_mode() + @Tensor.test() def f(x, m, W): tmp = x.mul(m) mm = tmp.matmul(W) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 59394ca655..f2b3c29b8d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -157,7 +157,7 @@ class Tensor: def __enter__(self): self.prev, Tensor.training = Tensor.training, self.mode def __exit__(self, exc_type, exc_value, traceback): Tensor.training = self.prev - class inference_mode(ContextDecorator): + class test(ContextDecorator): def __init__(self, mode:bool = True): self.mode = mode def __enter__(self): self.prev, Tensor.no_grad = Tensor.no_grad, self.mode def __exit__(self, exc_type, exc_value, traceback): Tensor.no_grad = self.prev