hotfix: examples/transformer.py

This commit is contained in:
George Hotz
2024-01-09 19:25:37 -08:00
parent 145718a90f
commit ae83733431
2 changed files with 6 additions and 4 deletions

View File

@@ -28,7 +28,8 @@ if __name__ == "__main__":
lr = 0.003
for i in range(10):
optim = Adam(get_parameters(model), lr=lr)
train(model, X_train, Y_train, optim, 50, BS=64)
# TODO: BUG! why doesn't the JIT work here?
train(model, X_train, Y_train, optim, 50, BS=64, allow_jit=False)
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
lr /= 1.2
print(f'reducing lr to {lr:.4f}')
@@ -37,6 +38,6 @@ if __name__ == "__main__":
for k in range(len(Y_test_preds)):
if (Y_test_preds[k] != Y_test[k]).any():
wrong+=1
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
a,b,c,x = X_test[k,:2].astype(np.int32), X_test[k,2:4].astype(np.int32), Y_test[k,-3:].astype(np.int32), Y_test_preds[k,-3:].astype(np.int32)
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')

View File

@@ -6,9 +6,8 @@ from tinygrad.jit import TinyJit
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):
@TinyJit
def train_step(x, y):
# network
out = model.forward(x) if hasattr(model, 'forward') else model(x)
@@ -22,6 +21,8 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
accuracy = (cat == y).mean()
return loss.realize(), accuracy.realize()
if allow_jit: train_step = TinyJit(train_step)
with Tensor.train():
losses, accuracies = [], []
for i in (t := trange(steps, disable=CI)):