From 9b1c3cd9ca12a92af61a402308ea10b2299e843f Mon Sep 17 00:00:00 2001 From: George Hotz Date: Wed, 18 Oct 2023 01:11:08 +0000 Subject: [PATCH] hlb_cifar: support EVAL_STEPS=1000, print when dataset is shuffled --- examples/hlb_cifar10.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index 2f78777ec0..dbdd09aef2 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -93,7 +93,7 @@ def train_cifar(): 'bias_decay': 1.08 * 6.45e-4 * BS/bias_scaler, 'non_bias_decay': 1.08 * 6.45e-4 * BS, 'final_lr_ratio': 0.025, - 'initial_div_factor': 1e16, + 'initial_div_factor': 1e16, 'label_smoothing': 0.20, 'momentum': 0.85, 'percent_start': 0.23, @@ -102,7 +102,7 @@ def train_cifar(): 'net': { 'kernel_size': 2, # kernel size for the whitening layer 'cutmix_size': 3, - 'cutmix_steps': 499, + 'cutmix_steps': 499, 'pad_amount': 2 }, 'ema': { @@ -187,7 +187,7 @@ def train_cifar(): return X_cropped.reshape((-1, 3, crop_size, crop_size)) - def cutmix(X, Y, mask_size=3): + def cutmix(X:Tensor, Y:Tensor, mask_size=3): # fill the square with randomly selected images from the same batch mask = make_square_mask(X.shape, mask_size) order = list(range(0, X.shape[0])) @@ -200,9 +200,10 @@ def train_cifar(): return X_cutmix, Y_cutmix # the operations that remain inside batch fetcher is the ones that involves random operations - def fetch_batches(X_in, Y_in, BS, is_train): + def fetch_batches(X_in:Tensor, Y_in:Tensor, BS:int, is_train:bool): step = 0 while True: + st = time.monotonic() X, Y = X_in, Y_in order = list(range(0, X.shape[0])) random.shuffle(order) @@ -211,6 +212,8 @@ def train_cifar(): X = Tensor.where(Tensor.rand(X.shape[0],1,1,1) < 0.5, X[..., ::-1], X) # flip LR if step >= hyp['net']['cutmix_steps']: X, Y = cutmix(X, Y, mask_size=hyp['net']['cutmix_size']) X, Y = X.numpy(), Y.numpy() + et = time.monotonic() + print(f"shuffling {'training' if is_train else 'test'} dataset in {(et-st)*1e3:.2f} ms") for i in range(0, X.shape[0], BS): # pad the last batch batch_end = min(i+BS, Y.shape[0]) @@ -344,8 +347,9 @@ def train_cifar(): i = 0 batcher = fetch_batches(X_train, Y_train, BS=BS, is_train=True) with Tensor.train(): + st = time.monotonic() while i <= STEPS: - if i%100 == 0 and i > 1: + if i%getenv("EVAL_STEPS", 100) == 0 and i > 1: # Use Tensor.training = False here actually bricks batchnorm, even with track_running_stats=True corrects = [] corrects_ema = [] @@ -399,7 +403,6 @@ def train_cifar(): if getenv("DIST"): X, Y = X.chunk(world_size, 0)[rank], Y.chunk(world_size, 0)[rank] GlobalCounters.reset() - st = time.monotonic() loss = train_step_jitted(model, [opt_bias, opt_non_bias], [lr_sched_bias, lr_sched_non_bias], X, Y) et = time.monotonic() loss_cpu = loss.numpy() @@ -413,6 +416,7 @@ def train_cifar(): print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used/1e9:.2f} GB used, {GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") else: print(f"{i:3d} {(cl-st)*1000.0:7.2f} ms run, {(et-st)*1000.0:7.2f} ms python, {(cl-et)*1000.0:7.2f} ms CL, {loss_cpu:7.2f} loss, {opt_non_bias.lr.numpy()[0]:.6f} LR, {world_size*GlobalCounters.mem_used/1e9:.2f} GB used, {world_size*GlobalCounters.global_ops*1e-9/(cl-st):9.2f} GFLOPS") + st = cl i += 1 if __name__ == "__main__":