mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Evaluation in Transformers (#218)
* 2serious * load/save * fixing GPU * added DEBUG * needs BatchNorm or doesn't learn anything * old file not needed * added conv biases * added extra/training.py and checkpoint * assert in test only * save * padding * num_classes * checkpoint * checkpoints for padding * training was broken * merge * rotation augmentation * more aug * needs testing * streamline augment, augment is fast thus bicubic * tidying up * transformer eval
This commit is contained in:
@@ -87,19 +87,15 @@ class Transformer:
|
||||
x = t(x)
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
||||
|
||||
from tinygrad.optim import Adam
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4)
|
||||
|
||||
#in1 = Tensor.zeros(20, 6, 128)
|
||||
#ret = model.forward(in1)
|
||||
#print(ret.shape)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
optim = Adam(get_parameters(model), lr=0.001)
|
||||
train(model, X_train, Y_train, optim, 100)
|
||||
|
||||
|
||||
train(model, X_train, Y_train, optim, 500)
|
||||
|
||||
evaluate(model, X_test, Y_test, num_classes=10)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user