mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Evaluation in Transformers (#218)
* 2serious * load/save * fixing GPU * added DEBUG * needs BatchNorm or doesn't learn anything * old file not needed * added conv biases * added extra/training.py and checkpoint * assert in test only * save * padding * num_classes * checkpoint * checkpoints for padding * training was broken * merge * rotation augmentation * more aug * needs testing * streamline augment, augment is fast thus bicubic * tidying up * transformer eval
This commit is contained in:
@@ -87,19 +87,15 @@ class Transformer:
|
||||
x = t(x)
|
||||
x = x.reshape(shape=(-1, x.shape[-1])).dot(self.final).logsoftmax()
|
||||
return x.reshape(shape=(bs, -1, x.shape[-1]))
|
||||
|
||||
|
||||
from tinygrad.optim import Adam
|
||||
if __name__ == "__main__":
|
||||
model = Transformer(10, 6, 2, 128, 4)
|
||||
|
||||
#in1 = Tensor.zeros(20, 6, 128)
|
||||
#ret = model.forward(in1)
|
||||
#print(ret.shape)
|
||||
|
||||
X_train, Y_train, X_test, Y_test = make_dataset()
|
||||
optim = Adam(get_parameters(model), lr=0.001)
|
||||
train(model, X_train, Y_train, optim, 100)
|
||||
|
||||
|
||||
train(model, X_train, Y_train, optim, 500)
|
||||
|
||||
evaluate(model, X_test, Y_test, num_classes=10)
|
||||
|
||||
|
||||
|
||||
@@ -44,10 +44,10 @@ def train(model, X_train, Y_train, optim, steps, BS=128, device=Device.CPU, loss
|
||||
|
||||
def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128):
|
||||
def numpy_eval(num_classes):
|
||||
Y_test_preds_out = np.zeros((len(Y_test),num_classes))
|
||||
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
|
||||
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=1)
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=len(Y_test.shape))
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
|
||||
if num_classes is None: num_classes = Y_test.max().astype(int)+1
|
||||
|
||||
Reference in New Issue
Block a user