mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
incorporate changes
This commit is contained in:
@@ -78,7 +78,7 @@ class TestMNIST(unittest.TestCase):
|
||||
|
||||
# evaluate
|
||||
def numpy_eval():
|
||||
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28))))
|
||||
Y_test_preds_out = model.forward(Tensor(X_test.reshape((-1, 28*28)).astype(np.float32)))
|
||||
Y_test_preds = np.argmax(Y_test_preds_out.data, axis=1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
|
||||
@@ -86,5 +86,6 @@ class TestMNIST(unittest.TestCase):
|
||||
print("test set accuracy is %f" % accuracy)
|
||||
assert accuracy > 0.95
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user