mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 15:38:29 -05:00
hotfix: examples/transformer.py
This commit is contained in:
@@ -28,7 +28,8 @@ if __name__ == "__main__":
|
||||
lr = 0.003
|
||||
for i in range(10):
|
||||
optim = Adam(get_parameters(model), lr=lr)
|
||||
train(model, X_train, Y_train, optim, 50, BS=64)
|
||||
# TODO: BUG! why doesn't the JIT work here?
|
||||
train(model, X_train, Y_train, optim, 50, BS=64, allow_jit=False)
|
||||
acc, Y_test_preds = evaluate(model, X_test, Y_test, num_classes=10, return_predict=True)
|
||||
lr /= 1.2
|
||||
print(f'reducing lr to {lr:.4f}')
|
||||
@@ -37,6 +38,6 @@ if __name__ == "__main__":
|
||||
for k in range(len(Y_test_preds)):
|
||||
if (Y_test_preds[k] != Y_test[k]).any():
|
||||
wrong+=1
|
||||
a,b,c,x = X_test[k,:2], X_test[k,2:4], Y_test[k,-3:], Y_test_preds[k,-3:]
|
||||
a,b,c,x = X_test[k,:2].astype(np.int32), X_test[k,2:4].astype(np.int32), Y_test[k,-3:].astype(np.int32), Y_test_preds[k,-3:].astype(np.int32)
|
||||
print(f'{a[0]}{a[1]} + {b[0]}{b[1]} = {x[0]}{x[1]}{x[2]} (correct: {c[0]}{c[1]}{c[2]})')
|
||||
print(f'Wrong predictions: {wrong}, acc = {acc:.4f}')
|
||||
|
||||
@@ -6,9 +6,8 @@ from tinygrad.jit import TinyJit
|
||||
|
||||
|
||||
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
|
||||
transform=lambda x: x, target_transform=lambda x: x, noloss=False):
|
||||
transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):
|
||||
|
||||
@TinyJit
|
||||
def train_step(x, y):
|
||||
# network
|
||||
out = model.forward(x) if hasattr(model, 'forward') else model(x)
|
||||
@@ -22,6 +21,8 @@ def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: ou
|
||||
accuracy = (cat == y).mean()
|
||||
return loss.realize(), accuracy.realize()
|
||||
|
||||
if allow_jit: train_step = TinyJit(train_step)
|
||||
|
||||
with Tensor.train():
|
||||
losses, accuracies = [], []
|
||||
for i in (t := trange(steps, disable=CI)):
|
||||
|
||||
Reference in New Issue
Block a user