diff --git a/test/test_mnist.py b/test/test_mnist.py index da12322cc2..ce2a88d123 100644 --- a/test/test_mnist.py +++ b/test/test_mnist.py @@ -78,7 +78,7 @@ class TestMNIST(unittest.TestCase): # evaluate def numpy_eval(): - Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)))) + Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32))) Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1) return (Y_test == Y_test_preds).mean() @@ -86,5 +86,6 @@ class TestMNIST(unittest.TestCase): print("test set accuracy is %f" % accuracy) assert accuracy > 0.95 + if __name__ == '__main__': unittest.main()