incorporate changes

This commit is contained in:
Adrian Garcia Badaracco
2020-10-21 13:21:44 -05:00
parent 58b4f191a4
commit 9a8be135a7

View File

@@ -78,7 +78,7 @@ class TestMNIST(unittest.TestCase):
# evaluate
def numpy_eval():
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28))))
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
return (Y_test == Y_test_preds).mean()
@@ -86,5 +86,6 @@ class TestMNIST(unittest.TestCase):
print("test set accuracy is %f" % accuracy)
assert accuracy > 0.95
if __name__ == '__main__':
unittest.main()