mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
transformer >99.98% test accuracy in ~30s (#230)
* transformer * BS might divide len(Y_test) * outoput when accuracy is high * more readeable * fixed loss in serious_mnist for new API
This commit is contained in:
@@ -10,7 +10,7 @@ from tinygrad.tensor import Tensor, GPU
|
||||
from tinygrad.nn import BatchNorm2D
|
||||
from extra.utils import get_parameters
|
||||
from test_mnist import fetch_mnist
|
||||
from extra.training import train, evaluate
|
||||
from extra.training import train, evaluate, sparse_categorical_crossentropy
|
||||
import tinygrad.optim as optim
|
||||
from extra.augment import augment_img
|
||||
GPU = os.getenv("GPU", None) is not None
|
||||
@@ -106,7 +106,7 @@ if __name__ == "__main__":
|
||||
BS = 32
|
||||
|
||||
lmbd = 0.00025
|
||||
lossfn = lambda out,y: out.mul(y).mean() + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
|
||||
lossfn = lambda out,y: sparse_categorical_crossentropy(out, y) + lmbd*(model.weight1.abs() + model.weight2.abs()).sum()
|
||||
X_train, Y_train, X_test, Y_test = fetch_mnist()
|
||||
steps = len(X_train)//BS
|
||||
np.random.seed(1337)
|
||||
@@ -133,6 +133,6 @@ if __name__ == "__main__":
|
||||
for epoch in range(1,epochs+1):
|
||||
#first epoch without augmentation
|
||||
X_aug = X_train if epoch == 1 else augment_img(X_train)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, gpu=GPU, BS=BS)
|
||||
train(model, X_aug, Y_train, optimizer, steps=steps, lossfn=lossfn, BS=BS)
|
||||
accuracy = evaluate(model, X_test, Y_test, BS=BS)
|
||||
model.save('examples/checkpoint'+str("%.0f" % (accuracy*1.0e6)))
|
||||
|
||||
@@ -30,10 +30,18 @@ if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
optim = Adam(get_parameters(model), lr=0.001)
|
||||
|
||||
for i in range(5):
|
||||
train(model, X_train, Y_train, optim, 500, BS=32, device=Device.GPU if os.getenv("GPU") else Device.CPU)
|
||||
evaluate(model, X_test, Y_test, num_classes=10)
|
||||
|
||||
|
||||
lr = 0.003
|
||||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.4f}')
|
||||
if acc > 0.998:
|
||||
wrong=0
|
||||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong+=1
|
||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
||||
|
||||
@@ -40,17 +40,17 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=sparse_categoric
|
||||
accuracies.append(accuracy)
|
||||
t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
|
||||
|
||||
def evaluate(model, X_test, Y_test, num_classes=None, BS=128):
|
||||
def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False):
|
||||
Tensor.training = False
|
||||
def numpy_eval(num_classes):
|
||||
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
|
||||
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
|
||||
for i in trange((len(Y_test)-1)//BS+1, disable=os.getenv('CI') is not None):
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS])).cpu().data
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
return (Y_test == Y_test_preds).mean(), Y_test_preds
|
||||
|
||||
if num_classes is None: num_classes = Y_test.max().astype(int)+1
|
||||
accuracy = numpy_eval(num_classes)
|
||||
print("test set accuracy is %f" % accuracy)
|
||||
return accuracy
|
||||
acc, Y_test_pred = numpy_eval(num_classes)
|
||||
print("test set accuracy is %f" % acc)
|
||||
return (acc, Y_test_pred) if return_predict else acc
|
||||
|
||||
|
||||
Reference in New Issue
Block a user