modernize hlb_cifar (#3146)

* modernize hlb_cifar

do more of the work in Tensor space instead of numpy, clean up dtypes, and use more Tensor methods.

* eigens are float64
chenyu
2024-01-16 01:35:11 -05:00
committed by GitHub
parent 2088937206
commit ec5a212b0a

@@ -119,8 +119,7 @@ def train_cifar():
# NOTE: np.linalg.eigh only supports float32 so the whitening layer weights need to be converted to float16 manually
def whitening(X, kernel_size=hyp['net']['kernel_size']):
def _cov(X):
X = X/np.sqrt(X.shape[0] - 1)
return X.T @ X
return (X.T @ X) / (X.shape[0] - 1)
def _patches(data, patch_size=(kernel_size,kernel_size)):
h, w = patch_size
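
Side note (not part of the diff): the old two-step _cov above and the new one-liner compute the same covariance matrix, since scaling X by 1/sqrt(n-1) before the matmul is the same as dividing the product by n-1. A quick standalone numpy check:

  import numpy as np
  X = np.random.rand(512, 27)                               # toy data, shape chosen arbitrarily
  old = (X / np.sqrt(X.shape[0] - 1)).T @ (X / np.sqrt(X.shape[0] - 1))
  new = (X.T @ X) / (X.shape[0] - 1)
  assert np.allclose(old, new)                              # identical up to float rounding
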
@@ -144,9 +143,11 @@ def train_cifar():
divisor = y.shape[1]
assert isinstance(divisor, int), "only supported int divisor"
y = (1 - label_smoothing)*y + label_smoothing / divisor
if reduction=='none': return -x.log_softmax(axis=1).mul(y).sum(axis=1)
if reduction=='sum': return -x.log_softmax(axis=1).mul(y).sum(axis=1).sum()
return -x.log_softmax(axis=1).mul(y).sum(axis=1).mean()
ret = -x.log_softmax(axis=1).mul(y).sum(axis=1)
if reduction=='none': return ret
if reduction=='sum': return ret.sum()
if reduction=='mean': return ret.mean()
raise NotImplementedError(reduction)
# ========== Preprocessing ==========
# TODO currently this only works for RGB in format of NxCxHxW and pads the HxW
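Side note (not part of the diff): the cross_entropy refactor above just computes the per-sample loss once and reduces it three ways. A minimal sketch of the relationship, using only Tensor methods that already appear in the diff:

  from tinygrad.tensor import Tensor
  logits = Tensor([[2.0, 0.5, 0.1], [0.2, 1.5, 0.3]])       # made-up logits, batch of 2
  y = Tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])            # one-hot targets
  per_sample = -logits.log_softmax(axis=1).mul(y).sum(axis=1)   # reduction='none'
  print(per_sample.sum().numpy())                           # reduction='sum'
  print(per_sample.mean().numpy())                          # reduction='mean'
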
@@ -172,19 +173,19 @@ def train_cifar():
is_even = int(mask_size % 2 == 0)
center_max = shape[-2]-mask_size//2-is_even
center_min = mask_size//2-is_even
center_x = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
center_y = (Tensor.rand(shape[0])*(center_max-center_min)+center_min).floor()
center_x = Tensor.randint(shape[0], low=center_min, high=center_max)
center_y = Tensor.randint(shape[0], low=center_min, high=center_max)
d_x = Tensor.arange(0, shape[-1]).reshape((1,1,1,shape[-1])) - center_x.reshape((-1,1,1,1))
d_y = Tensor.arange(0, shape[-2]).reshape((1,1,shape[-2],1)) - center_y.reshape((-1,1,1,1))
d_x =(d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
d_y =(d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
d_x = (d_x >= -(mask_size // 2) + is_even) * (d_x <= mask_size // 2)
d_y = (d_y >= -(mask_size // 2) + is_even) * (d_y <= mask_size // 2)
mask = d_y * d_x
return mask
return mask.cast(dtypes.bool)
def random_crop(X:Tensor, crop_size=32):
mask = make_square_mask(X.shape, crop_size)
mask = mask.repeat((1,3,1,1))
X_cropped = Tensor(X.flatten().numpy()[mask.flatten().numpy().astype(bool)])
mask = mask.expand((-1,3,-1,-1))
X_cropped = Tensor(X.numpy()[mask.numpy()])
return X_cropped.reshape((-1, 3, crop_size, crop_size))
def cutmix(X:Tensor, Y:Tensor, mask_size=3):
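
Side note (not part of the diff): the mask.cast(dtypes.bool) above matters because X.numpy()[mask.numpy()] relies on numpy boolean masking; an integer 0/1 mask would be treated as fancy indexing instead. Plain numpy illustration:

  import numpy as np
  x = np.array([10., 20., 30.])
  int_mask = np.array([0, 1, 1])                            # 0/1 ints: fancy indexing
  bool_mask = int_mask.astype(bool)                         # bools: element selection
  print(x[int_mask])                                        # [10. 20. 20.]  (wrong for cropping)
  print(x[bool_mask])                                       # [20. 30.]      (what random_crop wants)
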
@@ -192,6 +193,7 @@ def train_cifar():
mask = make_square_mask(X.shape, mask_size)
order = list(range(0, X.shape[0]))
random.shuffle(order)
# NOTE: Memory access fault if use getitem directly
X_patch = Tensor(X.numpy()[order,...])
Y_patch = Tensor(Y.numpy()[order])
X_cutmix = Tensor.where(mask, X_patch, X)
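
Aside (not part of the diff): Tensor.where(mask, X_patch, X) picks elements from X_patch where the now-boolean mask is true and from X elsewhere. A tiny sketch with made-up values:

  from tinygrad.tensor import Tensor
  mask = Tensor([[True, False], [False, True]])
  a = Tensor([[1., 2.], [3., 4.]])                          # stands in for X_patch
  b = Tensor([[9., 9.], [9., 9.]])                          # stands in for X
  print(Tensor.where(mask, a, b).numpy())                   # [[1. 9.] [9. 4.]]
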
@@ -215,7 +217,7 @@ def train_cifar():
et = time.monotonic()
print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms ({cnt})")
for i in range(0, X.shape[0], BS):
# pad the last batch
# pad the last batch # TODO: not correct for test
batch_end = min(i+BS, Y.shape[0])
x = Tensor(X[order[batch_end-BS:batch_end],:])
y = Tensor(Y[order[batch_end-BS:batch_end]])
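
Aside (not part of the diff): the "pad the last batch" slice re-uses earlier samples so every batch has exactly BS elements, which is presumably why the new TODO flags it as not correct for test evaluation. Toy illustration (BS and dataset size are made up):

  BS, n = 4, 10
  order = list(range(n))
  for i in range(0, n, BS):
    batch_end = min(i + BS, n)
    print(order[batch_end - BS:batch_end])
  # [0, 1, 2, 3]
  # [4, 5, 6, 7]
  # [6, 7, 8, 9]   <- 6 and 7 appear twice, so they would be counted twice at eval time
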
@@ -257,9 +259,9 @@ def train_cifar():
X_train, Y_train, X_test, Y_test = fetch_cifar()
# load data and label into GPU and convert to dtype accordingly
X_train, X_test = X_train.to(device=Device.DEFAULT).float(), X_test.to(device=Device.DEFAULT).float()
Y_train, Y_test = Y_train.to(device=Device.DEFAULT).float(), Y_test.to(device=Device.DEFAULT).float()
Y_train, Y_test = Y_train.to(device=Device.DEFAULT), Y_test.to(device=Device.DEFAULT)
# one-hot encode labels
Y_train, Y_test = Tensor.eye(10)[Y_train.cast(dtypes.int32)], Tensor.eye(10)[Y_test.cast(dtypes.int32)]
Y_train, Y_test = Y_train.one_hot(10), Y_test.one_hot(10)
# preprocess data
X_train, X_test = X_train.sequential(transform), X_test.sequential(transform)
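
Side note (not part of the diff): one_hot replaces the old eye-indexing trick; both build the same 0/1 rows. A hedged sketch, assuming the tinygrad build used here (where Tensor.one_hot exists, as the diff implies):

  from tinygrad.tensor import Tensor
  y = Tensor([3, 0, 2])                                     # integer class labels
  print(Tensor.eye(4)[y].numpy())                           # old style: index rows of the identity
  print(y.one_hot(4).numpy())                               # new style: same 0/1 rows (dtype may differ)
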
@@ -294,13 +296,13 @@ def train_cifar():
initial_div_factor = hyp['opt']['initial_div_factor']
final_lr_ratio = hyp['opt']['final_lr_ratio']
pct_start = hyp['opt']['percent_start']
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'] , pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
lr_sched_bias = OneCycleLR(opt_bias, max_lr=hyp['opt']['bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
lr_sched_non_bias = OneCycleLR(opt_non_bias, max_lr=hyp['opt']['non_bias_lr'], pct_start=pct_start, div_factor=initial_div_factor, final_div_factor=1./(initial_div_factor*final_lr_ratio), total_steps=STEPS)
def train_step(model, optimizer, lr_scheduler, X, Y):
out = model(X)
loss_batchsize_scaler = 512/BS
loss = cross_entropy(out, Y, reduction='none' ,label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
loss = cross_entropy(out, Y, reduction='none', label_smoothing=hyp['opt']['label_smoothing']).mul(hyp['opt']['loss_scale_scaler']*loss_batchsize_scaler).sum().div(hyp['opt']['loss_scale_scaler'])
if not getenv("DISABLE_BACKWARD"):
# index 0 for bias and 1 for non-bias
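
Aside (not part of the diff): the inverted final_div_factor above is arranged so the schedule ends at max_lr * final_lr_ratio, assuming OneCycleLR's usual min_lr = (max_lr / div_factor) / final_div_factor convention. Arithmetic sketch with made-up numbers:

  max_lr, initial_div_factor, final_lr_ratio = 0.01, 1e16, 0.07   # hypothetical hyperparameters
  initial_lr = max_lr / initial_div_factor
  final_div_factor = 1. / (initial_div_factor * final_lr_ratio)
  final_lr = initial_lr / final_div_factor
  assert abs(final_lr - max_lr * final_lr_ratio) < 1e-12          # schedule bottoms out at max_lr * final_lr_ratio
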
@@ -351,8 +353,7 @@ def train_cifar():
with Tensor.train():
st = time.monotonic()
while i <= STEPS:
if i%getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
st_eval = time.monotonic()
if i % getenv("EVAL_STEPS", STEPS) == 0 and i > 1:
# Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True
corrects = []
corrects_ema = []
@@ -399,7 +400,7 @@ def train_cifar():
print(f"eval {correct_sum}/{correct_len} {acc:.2f}%, {(sum(losses)/len(losses)):7.2f} val_loss STEP={i} (in {(time.monotonic()-st)*1e3:.2f} ms)")
if model_ema: print(f"eval ema {correct_sum_ema}/{correct_len_ema} {acc_ema:.2f}%, {(sum(losses_ema)/len(losses_ema)):7.2f} val_loss STEP={i}")
if STEPS == 0 or i==STEPS: break
if STEPS == 0 or i == STEPS: break
X, Y = next(batcher)
if getenv("DIST"):
X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank]