incorporate changes

This commit is contained in:
Adrian Garcia Badaracco
2020-10-21 13:21:44 -05:00
parent 58b4f191a4
commit 9a8be135a7

View File

@@ -78,7 +78,7 @@ class TestMNIST(unittest.TestCase):
# evaluate # evaluate
def numpy_eval(): def numpy_eval():
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)))) Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1) Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
return (Y_test == Y_test_preds).mean() return (Y_test == Y_test_preds).mean()
@@ -86,5 +86,6 @@ class TestMNIST(unittest.TestCase):
print("test set accuracy is %f" % accuracy) print("test set accuracy is %f" % accuracy)
assert accuracy > 0.95 assert accuracy > 0.95
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()