mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
transformer >99.98% test accuracy in ~30s (#230)
* transformer * BS might not divide len(Y_test) * output when accuracy is high * more readable * fixed loss in serious_mnist for new API
This commit is contained in:
@@ -40,17 +40,17 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
|
||||
accuracies.append(accuracy)
|
||||
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
|
||||
|
||||
def evaluate(model, X_test, Y_test, num_classes=None, BS=128):
|
||||
def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False):
|
||||
Tensor.training = False
|
||||
def numpy_eval(num_classes):
|
||||
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
|
||||
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
|
||||
for i in trange((len(Y_test)-1)//BS+1, disable=os.getenv('CI') is not None):
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS])).cpu().data
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
return (Y_test == Y_test_preds).mean(), Y_test_preds
|
||||
|
||||
if num_classes is None: num_classes = Y_test.max().astype(int)+1
|
||||
accuracy = numpy_eval(num_classes)
|
||||
print("test set accuracy is %f" % accuracy)
|
||||
return accuracy
|
||||
acc, Y_test_pred = numpy_eval(num_classes)
|
||||
print("test set accuracy is %f" % acc)
|
||||
return (acc, Y_test_pred) if return_predict else acc
|
||||
|
||||
|
||||
Reference in New Issue
Block a user