mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
incorporate changes
This commit is contained in:
@@ -78,7 +78,7 @@ class TestMNIST(unittest.TestCase):
|
|||||||
|
|
||||||
# evaluate
|
# evaluate
|
||||||
def numpy_eval():
|
def numpy_eval():
|
||||||
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28))))
|
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
|
||||||
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
|
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
|
||||||
return (Y_test == Y_test_preds).mean()
|
return (Y_test == Y_test_preds).mean()
|
||||||
|
|
||||||
@@ -86,5 +86,6 @@ class TestMNIST(unittest.TestCase):
|
|||||||
print("test set accuracy is %f" % accuracy)
|
print("test set accuracy is %f" % accuracy)
|
||||||
assert accuracy > 0.95
|
assert accuracy > 0.95
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user