From ec5a212b0a116aa7482cd6c9e160e27a3097683d Mon Sep 17 00:00:00 2001
From: chenyu
Date: Tue, 16 Jan 2024 01:35:11 -0500
Subject: [PATCH] modernize hlb_cifar (#3146)

* modernize hlb_cifar

do more things in Tensor space instead of numpy, clean up dtypes and use more Tensor methods.

* eigens are float64
---
 examples/hlb_cifar10.py | 41 +++++++++++++++++++++--------------------
 1 file changed, 21 insertions(+), 20 deletions(-)

diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py
index 6f17fc1967..6de9799967 100644
--- a/examples/hlb_cifar10.py
+++ b/examples/hlb_cifar10.py
@@ -119,8 +119,7 @@ def train_cifar():
   # NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
   def whitening(X, kernel_size=hyp['net']['kernel_size']):
     def _cov(X):
-      X = X/np.sqrt(X.shape[0] - 1)
-      return X.T @ X
+      return (X.T @ X) / (X.shape[0] - 1)
 
     def _patches(data, patch_size=(kernel_size,kernel_size)):
       h, w = patch_size
@@ -144,9 +143,11 @@ def train_cifar():
     divisor = y.shape[1]
     assert isinstance(divisor, int), "only supported int divisor"
     y = (1 - label_smoothing)*y + label_smoothing / divisor
-    if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
-    if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
-    return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
+    ret = -x.log_softmax(axis=1).mul(y).sum(axis=1)
+    if reduction=='none': return ret
+    if reduction=='sum': return ret.sum()
+    if reduction=='mean': return ret.mean()
+    raise NotImplementedError(reduction)
 
   # ========== Preprocessing ==========
   # TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
@@ -172,19 +173,19 @@ def train_cifar():
     is_even = int(mask_size % 2 == 0)
     center_max = shape[-2]-mask_size//2-is_even
     center_min = mask_size//2-is_even
-    center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
-    center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
+    center_x = Tensor.randint(shape[0], low=center_min, high=center_max)
+    center_y = Tensor.randint(shape[0], low=center_min, high=center_max)
     d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
     d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
-    d_x =(d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
-    d_y =(d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
+    d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
+    d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
     mask = d_y * d_x
-    return mask
+    return mask.cast(dtypes.bool)
 
   def random_crop(X:Tensor, crop_size=32):
     mask = make_square_mask(X.shape, crop_size)
-    mask = mask.repeat((1,3,1,1))
-    X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
+    mask = mask.expand((-1,3,-1,-1))
+    X_cropped = Tensor(X.numpy()[mask.numpy()])
     return X_cropped.reshape((-1, 3, crop_size, crop_size))
 
   def cutmix(X:Tensor, Y:Tensor, mask_size=3):
@@ -192,6 +193,7 @@ def train_cifar():
     mask = make_square_mask(X.shape, mask_size)
     order = list(range(0, X.shape[0]))
     random.shuffle(order)
+    # NOTE: Memory access fault if use getitem directly
     X_patch = Tensor(X.numpy()[order,...])
     Y_patch = Tensor(Y.numpy()[order])
     X_cutmix = Tensor.where(mask, X_patch, X)
@@ -215,7 +217,7 @@ def train_cifar():
       et = time.monotonic()
       print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
      for i in range(0, X.shape[0], BS):
-        # pad the last batch
+        # pad the last batch  # TODO: not correct for test
        batch_end = min(i+BS, Y.shape[0])
        x = Tensor(X[order[batch_end-BS:batch_end],:])
        y = Tensor(Y[order[batch_end-BS:batch_end]])
@@ -257,9 +259,9 @@ def train_cifar():
   X_train, Y_train, X_test, Y_test = fetch_cifar()
   # load data and label into GPU and convert to dtype accordingly
   X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
-  Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
+  Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
   # one-hot encode labels
-  Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)]
+  Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
   # preprocess data
   X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
 
@@ -294,13 +296,13 @@ def train_cifar():
   initial_div_factor = hyp['opt']['initial_div_factor']
   final_lr_ratio = hyp['opt']['final_lr_ratio']
   pct_start = hyp['opt']['percent_start']
-  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] , pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
+  lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
   lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
 
   def train_step(model, optimizer, lr_scheduler, X, Y):
     out = model(X)
     loss_batchsize_scaler = 512/BS
-    loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
+    loss = cross_entropy(out, Y, reduction='none', label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
 
     if not getenv("DISABLE_BACKWARD"):
       # index 0 for bias and 1 for non-bias
@@ -351,8 +353,7 @@ def train_cifar():
   with Tensor.train():
     st = time.monotonic()
     while i <= STEPS:
-      if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
-        st_eval = time.monotonic()
+      if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
         # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
         corrects = []
         corrects_ema = []
@@ -399,7 +400,7 @@ def train_cifar():
         print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
         if model_ema:
           print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
-      if STEPS == 0 or i==STEPS: break
+      if STEPS == 0 or i == STEPS: break
 
       X, Y = next(batcher)
       if getenv("DIST"): X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]
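
Not part of the patch: a minimal sketch of the label-encoding change ("use more Tensor methods"), assuming a tinygrad install of roughly this era where both indexing Tensor.eye with an integer Tensor (the pre-patch approach) and Tensor.one_hot (the post-patch approach) are available; the labels and shapes below are illustrative only.

    # sketch: the new one_hot call produces the same one-hot rows as the old eye-indexing
    import numpy as np
    from tinygrad import Tensor, dtypes

    labels = Tensor([3, 1, 4], dtype=dtypes.int32)      # hypothetical CIFAR-style class ids
    old_style = Tensor.eye(10)[labels]                  # pre-patch: index rows of an identity matrix
    new_style = labels.one_hot(10)                      # post-patch: dedicated Tensor method
    assert np.array_equal(old_style.numpy(), new_style.numpy())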