mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Transpose on GPU (#221)
* 2serious * load/save * fixing GPU * added DEBUG * needs BatchNorm or doesn't learn anything * old file not needed * added conv biases * added extra/training.py and checkpoint * assert in test only * save * padding * num_classes * checkpoint * checkpoints for padding * training was broken * merge * rotation augmentation * more aug * needs testing * streamline augment, augment is fast thus bicubic * tidying up * transformer eval * axis=-1 * transpose * test for permutation using torch.movedims * another test * line
This commit is contained in:
@@ -48,7 +48,7 @@ def evaluate(model, X_test, Y_test, num_classes=None, device=Device.CPU, BS=128)
|
||||
Y_test_preds_out = np.zeros(list(Y_test.shape)+[num_classes])
|
||||
for i in trange(len(Y_test)//BS, disable=os.getenv('CI') is not None):
|
||||
Y_test_preds_out[i*BS:(i+1)*BS] = model.forward(Tensor(X_test[i*BS:(i+1)*BS], device=device)).cpu().data
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=len(Y_test.shape))
|
||||
Y_test_preds = np.argmax(Y_test_preds_out, axis=-1)
|
||||
return (Y_test == Y_test_preds).mean()
|
||||
|
||||
if num_classes is None: num_classes = Y_test.max().astype(int)+1
|
||||
|
||||
Reference in New Issue
Block a user