mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
minor hlb_cifar cleanups (#3208)
mostly cosmetic. LATEBEAM=4 single 7900xtx 59.2 seconds
@@ -7,10 +7,10 @@ import random, time
import numpy as np
from typing import Optional
from extra.datasets import fetch_cifar, cifar_mean, cifar_std
-from extra.lr_scheduler import OneCycleLR
from tinygrad import nn, dtypes, Tensor, Device, GlobalCounters, TinyJit
from tinygrad.nn.state import get_state_dict, get_parameters
from tinygrad.nn import optim
+from extra.lr_scheduler import OneCycleLR
from tinygrad.helpers import Context, BEAM, WINO, getenv

BS, EVAL_BS, STEPS = getenv("BS", 512), getenv('EVAL_BS', 500), getenv("STEPS", 1000)
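Every knob in this script is read from an environment variable with a fallback default via `getenv`. A minimal sketch of the pattern (the variable names below are illustrative, not the full set the script uses):

```python
from tinygrad.helpers import getenv

# each setting is an env var with a default, e.g. BS=256 STEPS=500 python hlb_cifar10.py
BS    = getenv("BS", 512)      # training batch size
STEPS = getenv("STEPS", 1000)  # number of optimizer steps
print(f"training for {STEPS} steps with batch size {BS}")
```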
@@ -18,7 +18,7 @@ GPUS = [f'{Device.DEFAULT}:{i}' for i in range(getenv("GPUS", 1))]
assert BS % len(GPUS) == 0, f"{BS=} is not a multiple of {len(GPUS)=}"
for x in GPUS: Device[x]

-if getenv("HALF", 0):
+if getenv("HALF"):
dtypes.default_float = dtypes.float16
np_dtype = np.float16
else:
@@ -66,14 +66,14 @@ class SpeedyResNet:
ConvGroup(256, 512),
lambda x: x.max((2,3)),
nn.Linear(512, 10, bias=False),
-lambda x: x.mul(1./9)
+lambda x: x / 9.,
]

def __call__(self, x, training=True):
# pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with
# TODO: remove the pad but instead let the kernel optimize itself
forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net)
-return forward(x) if training else forward(x)*0.5 + forward(x[..., ::-1])*0.5
+return forward(x) if training else (forward(x) + forward(x[..., ::-1])) / 2.

# hyper-parameters were exactly the same as the original repo
bias_scaler = 58
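The eval branch of `__call__` now averages the prediction on the image and its left-right mirror instead of scaling each pass by 0.5, a cosmetic rewrite of the same test-time augmentation. A minimal sketch of the idea with a stand-in `model` (not the repo's SpeedyResNet):

```python
import numpy as np

def predict_tta(model, x: np.ndarray) -> np.ndarray:
  # average logits over the original batch and its horizontal flip;
  # x[..., ::-1] reverses the last (width) axis, the same trick as the diff line
  return (model(x) + model(x[..., ::-1])) / 2.

# usage with a dummy "model" that just pools over the spatial axes
dummy = lambda x: x.mean(axis=(2, 3))
out = predict_tta(dummy, np.random.rand(4, 3, 32, 32).astype(np.float32))
print(out.shape)  # (4, 3)
```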
@@ -108,8 +108,8 @@ hyp = {
def train_cifar():

def set_seed(seed):
-Tensor.manual_seed(getenv('SEED', seed))
-random.seed(getenv('SEED', seed))
+Tensor.manual_seed(seed)
+random.seed(seed)

# ========== Model ==========
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
@@ -174,14 +174,14 @@ def train_cifar():
random.shuffle(order)
X_patch = Tensor(X.numpy()[order,...])
Y_patch = Tensor(Y.numpy()[order])
-X_cutmix = Tensor.where(mask, X_patch, X)
+X_cutmix = mask.where(X_patch, X)
mix_portion = float(mask_size**2)/(X.shape[-2]*X.shape[-1])
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
return X_cutmix, Y_cutmix

# the operations that remain inside batch fetcher is the ones that involves random operations
def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool):
-step, cnt = 0, 0
+step, epoch = 0, 0
while True:
st = time.monotonic()
X, Y = X_in, Y_in
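The cutmix selection in the hunk above now uses the method form `mask.where(X_patch, X)` instead of the static `Tensor.where(mask, ...)`; both take `X_patch` pixels where the mask is set and `X` elsewhere. A rough numpy sketch of that selection and the matching label mix (shapes and the 8-pixel patch are illustrative, not the script's hyperparameters):

```python
import numpy as np

rng = np.random.default_rng(0)
X       = rng.standard_normal((4, 3, 32, 32)).astype(np.float32)  # original batch
X_patch = rng.standard_normal((4, 3, 32, 32)).astype(np.float32)  # shuffled batch to paste from
Y, Y_patch = np.eye(10)[rng.integers(0, 10, 4)], np.eye(10)[rng.integers(0, 10, 4)]

mask = np.zeros((4, 1, 32, 32), dtype=bool)
mask[..., 8:16, 8:16] = True                       # 8x8 region to paste (illustrative)
X_cutmix = np.where(mask, X_patch, X)              # same semantics as mask.where(X_patch, X)
mix_portion = float(8**2) / (32 * 32)              # fraction of pixels taken from the patch
Y_cutmix = mix_portion * Y_patch + (1. - mix_portion) * Y
```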
@@ -192,13 +192,13 @@ def train_cifar():
if getenv("RANDOM_CROP", 1):
X = random_crop(X, crop_size=32)
if getenv("RANDOM_FLIP", 1):
-X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR
+X = (Tensor.rand(X.shape[0],1,1,1) < 0.5).where(X.flip(-1), X) # flip LR
if getenv("CUTMIX", 1):
if step >= hyp['net']['cutmix_steps']:
X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size'])
X, Y = X.numpy(), Y.numpy()
et = time.monotonic()
-print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
+print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({epoch=})")
for i in range(0, X.shape[0], BS):
# pad the last batch # TODO: not correct for test
batch_end = min(i+BS, Y.shape[0])
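The left-right flip augmentation in the hunk above draws one coin flip per image and uses the boolean tensor's `where` to pick the mirrored or original copy. A small sketch using tinygrad's Tensor with a dummy batch:

```python
from tinygrad import Tensor

X = Tensor.rand(8, 3, 32, 32)                   # dummy batch of images
flip = Tensor.rand(X.shape[0], 1, 1, 1) < 0.5   # one random bool per image, broadcast over C,H,W
X_aug = flip.where(X.flip(-1), X)               # mirrored copy where flip is true, original otherwise
print(X_aug.shape)
```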
@@ -206,7 +206,7 @@ def train_cifar():
y = Tensor(Y[order[batch_end-BS:batch_end]])
step += 1
yield x, y
-cnt += 1
+epoch += 1
if not is_train: break

transform = [
@@ -232,7 +232,7 @@ def train_cifar():
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
Tensor.no_grad = False

-set_seed(hyp['seed'])
+set_seed(getenv('SEED', hyp['seed']))

X_train, Y_train, X_test, Y_test = fetch_cifar()
# load data and label into GPU and convert to dtype accordingly
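For context, the `net_ema_param.assign(...)` line in the last hunk is an exponential moving average of the network weights. A minimal sketch of that update rule with stand-in tensors (the `decay` value here is illustrative; the real script derives it from its hyperparameters):

```python
from tinygrad import Tensor

decay = 0.95                              # illustrative value, not the script's schedule
net_param     = Tensor.rand(512, 10)      # stand-in for a live weight
net_ema_param = Tensor.rand(512, 10)      # stand-in for its EMA shadow copy

# ema_new = decay * ema_old + (1 - decay) * current, same form as the diff line
net_ema_param.assign(net_ema_param.detach()*decay + net_param.detach()*(1.-decay)).realize()
```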