mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
modernize hlb_cifar (#3146)

* modernize hlb_cifar: do more things in Tensor space instead of numpy, clean up dtypes, and use more Tensor methods.
* eigens are float64
@@ -119,8 +119,7 @@ def train_cifar():
   # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
   def whitening(X, kernel_size=hyp['net']['kernel_size']):
     def _cov(X):
-      X = X/np.sqrt(X.shape[0] - 1)
-      return X.T @ X
+      return (X.T @ X) / (X.shape[0] - 1)
 
     def _patches(data, patch_size=(kernel_size,kernel_size)):
       h, w = patch_size
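The two `_cov` variants are algebraically identical: dividing X by sqrt(n-1) before the matmul is the same as dividing the Gram matrix X.T @ X by (n-1) afterwards, so the new one-liner just skips the intermediate scaled copy. A quick numpy check, with a made-up shape rather than the real patch matrix:

    import numpy as np

    X = np.random.randn(256, 27).astype(np.float64)   # placeholder data; any (n, d) matrix works
    old = (X / np.sqrt(X.shape[0] - 1)).T @ (X / np.sqrt(X.shape[0] - 1))
    new = (X.T @ X) / (X.shape[0] - 1)
    assert np.allclose(old, new)                        # same sample covariance either way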
@@ -144,9 +143,11 @@ def train_cifar():
     divisor = y.shape[1]
     assert isinstance(divisor, int), "only supported int divisor"
     y = (1 - label_smoothing)*y + label_smoothing / divisor
-    if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
-    if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
-    return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
+    ret = -x.log_softmax(axis=1).mul(y).sum(axis=1)
+    if reduction=='none': return ret
+    if reduction=='sum': return ret.sum()
+    if reduction=='mean': return ret.mean()
+    raise NotImplementedError(reduction)
 
   # ========== Preprocessing ==========
   # TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
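The rewrite computes the per-sample loss once and then dispatches on `reduction`, and an unknown reduction string now raises instead of silently falling back to the mean. Roughly how the three modes relate, sketched in numpy with arbitrary example logits (not tinygrad code, just the arithmetic):

    import numpy as np

    def log_softmax(x): return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

    x = np.random.randn(4, 10)                      # arbitrary example logits
    y = np.eye(10)[np.array([1, 3, 5, 7])]          # one-hot targets
    ret = -(log_softmax(x) * y).sum(axis=1)         # reduction='none': one loss per sample
    loss_sum, loss_mean = ret.sum(), ret.mean()     # 'sum' and 'mean' just reduce that same vector
    assert np.isclose(loss_mean, loss_sum / len(ret))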
@@ -172,19 +173,19 @@ def train_cifar():
     is_even = int(mask_size % 2 == 0)
     center_max = shape[-2]-mask_size//2-is_even
     center_min = mask_size//2-is_even
-    center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
-    center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
+    center_x = Tensor.randint(shape[0], low=center_min, high=center_max)
+    center_y = Tensor.randint(shape[0], low=center_min, high=center_max)
     d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
     d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
-    d_x =(d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
-    d_y =(d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
+    d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
+    d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
     mask = d_y * d_x
-    return mask
+    return mask.cast(dtypes.bool)
 
   def random_crop(X:Tensor, crop_size=32):
     mask = make_square_mask(X.shape, crop_size)
-    mask = mask.repeat((1,3,1,1))
-    X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
+    mask = mask.expand((-1,3,-1,-1))
+    X_cropped = Tensor(X.numpy()[mask.numpy()])
    return X_cropped.reshape((-1, 3, crop_size, crop_size))
 
   def cutmix(X:Tensor, Y:Tensor, mask_size=3):
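Two things change here: the random centers come from Tensor.randint instead of a scale-and-floor of Tensor.rand, and the square mask is returned as a boolean tensor, so random_crop can expand it over the channel dimension and index the numpy array with it directly (boolean masking keeps exactly 3*crop_size*crop_size elements per image). A small numpy sketch of that boolean-mask crop, with a hypothetical 4x3x32x32 batch and a fixed crop location for brevity:

    import numpy as np

    X = np.random.rand(4, 3, 32, 32).astype(np.float32)   # hypothetical NCHW batch
    crop = 8
    mask = np.zeros((4, 1, 32, 32), dtype=bool)
    mask[:, :, 5:5+crop, 9:9+crop] = True                  # one square per image
    mask = np.broadcast_to(mask, X.shape)                  # analogous to mask.expand((-1,3,-1,-1))
    X_cropped = X[mask].reshape(-1, 3, crop, crop)         # boolean indexing, then reshape as in the diff
    assert X_cropped.shape == (4, 3, crop, crop)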
@@ -192,6 +193,7 @@ def train_cifar():
     mask = make_square_mask(X.shape, mask_size)
     order = list(range(0, X.shape[0]))
     random.shuffle(order)
+    # NOTE: Memory access fault if use getitem directly
     X_patch = Tensor(X.numpy()[order,...])
     Y_patch = Tensor(Y.numpy()[order])
     X_cutmix = Tensor.where(mask, X_patch, X)
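For context, the surrounding cutmix code builds a shuffled copy of the batch through numpy indexing (per the new NOTE about getitem) and then splices the masked square from the shuffled images into the originals with a where-select. A minimal numpy sketch of that splice, with made-up shapes:

    import numpy as np

    X = np.random.rand(4, 3, 32, 32).astype(np.float32)    # hypothetical batch
    order = np.random.permutation(X.shape[0])               # shuffled pairing, like random.shuffle(order)
    mask = np.zeros((4, 1, 32, 32), dtype=bool)
    mask[:, :, 12:15, 20:23] = True                          # a 3x3 square per image
    X_cutmix = np.where(mask, X[order], X)                   # inside the square: partner image; elsewhere: original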
@@ -215,7 +217,7 @@ def train_cifar():
       et = time.monotonic()
       print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
       for i in range(0, X.shape[0], BS):
-        # pad the last batch
+        # pad the last batch # TODO: not correct for test
         batch_end = min(i+BS, Y.shape[0])
         x = Tensor(X[order[batch_end-BS:batch_end],:])
         y = Tensor(Y[order[batch_end-BS:batch_end]])
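The slice `order[batch_end-BS:batch_end]` always yields a full batch of BS indices; when the dataset size is not a multiple of BS, the final batch simply re-reads the tail of the shuffled order, which is what the new TODO flags as not quite right for the test set. A tiny sketch of that slicing with toy numbers:

    order, BS, N = list(range(10)), 4, 10
    for i in range(0, N, BS):
      batch_end = min(i + BS, N)
      print(order[batch_end - BS:batch_end])   # [0,1,2,3], [4,5,6,7], [6,7,8,9]: the last batch overlaps the previous one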
@@ -257,9 +259,9 @@ def train_cifar():
   X_train, Y_train, X_test, Y_test = fetch_cifar()
   # load data and label into GPU and convert to dtype accordingly
   X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
-  Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
+  Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
   # one-hot encode labels
-  Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)]
+  Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
   # preprocess data
   X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
 
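Keeping the labels in their integer dtype (no .float()) and calling one_hot on them replaces the Tensor.eye(10) row-gather; both produce an (N, 10) matrix of zeros and ones. The equivalence in numpy terms, with a few example class ids:

    import numpy as np

    labels = np.array([3, 0, 9, 1])                 # example class ids, kept as integers
    via_eye = np.eye(10)[labels]                    # old style: gather rows from an identity matrix
    via_onehot = (labels[:, None] == np.arange(10)).astype(np.float32)  # what a one_hot op computes
    assert np.array_equal(via_eye.astype(np.float32), via_onehot)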
@@ -294,13 +296,13 @@ def train_cifar():
   initial_div_factor = hyp['opt']['initial_div_factor']
   final_lr_ratio = hyp['opt']['final_lr_ratio']
   pct_start = hyp['opt']['percent_start']
-  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] , pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
+  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
   lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
 
   def train_step(model, optimizer, lr_scheduler, X, Y):
     out = model(X)
     loss_batchsize_scaler = 512/BS
-    loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
+    loss = cross_entropy(out, Y, reduction='none', label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
 
     if not getenv("DISABLE_BACKWARD"):
       # index 0 for bias and 1 for non-bias
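The loss line (only a comma-spacing fix here) folds two scalings into one expression: reduction='none' followed by .sum() gives a per-batch sum, the 512/BS factor normalizes that sum to a reference batch size of 512, and loss_scale_scaler is multiplied in and then divided back out, so it cancels in the value that is returned and only changes the magnitude of the intermediate products (presumably a precision aid for half-precision training; the hyperparameter name comes from hyp['opt']). The arithmetic, as a plain-Python sketch with made-up numbers:

    BS, loss_scale_scaler = 256, 32.0                # hypothetical values standing in for hyp['opt']
    per_sample = [0.9, 1.1, 1.0, 1.2]                # pretend reduction='none' output (4 samples for brevity)
    loss_batchsize_scaler = 512 / BS
    scaled = sum(v * loss_scale_scaler * loss_batchsize_scaler for v in per_sample)
    reported = scaled / loss_scale_scaler            # the final .div cancels the scale in the reported loss
    assert abs(reported - sum(per_sample) * loss_batchsize_scaler) < 1e-9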
@@ -351,8 +353,7 @@ def train_cifar():
   with Tensor.train():
     st = time.monotonic()
     while i <= STEPS:
-      if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
-        st_eval = time.monotonic()
+      if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
         # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
         corrects = []
         corrects_ema = []
@@ -399,7 +400,7 @@ def train_cifar():
         print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
         if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
 
-      if STEPS == 0 or i==STEPS: break
+      if STEPS == 0 or i == STEPS: break
       X, Y = next(batcher)
       if getenv("DIST"):
         X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]