From efe64ebeaf0c51866fa0b93f4d221a75d054adcc Mon Sep 17 00:00:00 2001
From: Francis Lata
Date: Wed, 22 Jan 2025 09:56:38 -0800
Subject: [PATCH] enable lr scheduler and fix benchmark timing

---
 examples/mlperf/model_train.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 3256e30598..fd64b16512 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -377,7 +377,6 @@ def train_retinanet():
     return x.shard(GPUS, axis=0).realize(), y_bboxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
 
   def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
-    # TODO: refactor this a bit more so we don't have to recreate it, unlike what MLPerf script is doing
     def _lr_lambda(e):
       e = e + start_iter
       if e >= warmup_iters: return 1.0
@@ -394,7 +393,7 @@ def train_retinanet():
 
     loss.backward()
     optim.step()
-    # lr_scheduler.step()
+    lr_scheduler.step()
 
     return loss.realize(), losses
 
@@ -435,6 +434,8 @@ def train_retinanet():
 
   for p in params: p.to_(GPUS)
 
+  step_times, start_epoch = [], 0
+
   # ** optimizer **
   optim = Adam(params, lr=lr)
 
@@ -442,21 +443,17 @@ def train_retinanet():
   train_dataset = COCO(download_dataset(BASE_DIR, "train"))
   val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
 
+  # ** lr scheduler **
   config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), bs) // bs
-  step_times = []
+  start_iter, warmup_iters = start_epoch * steps_in_train_epoch, lr_warmup_epochs * steps_in_train_epoch
+  lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)
 
   # ** training loop **
-  for e in range(1, num_epochs + 1):
+  for e in range(start_epoch, num_epochs):
     train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=bs, seed=seed)
     it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, _data_get(it)
 
-    # if e < LR_WARMUP_EPOCHS:
-    #   start_iter, warmup_iters = e * train_dataset_len, LR_WARMUP_EPOCHS * train_dataset_len
-    #   lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, LR_WARMUP_FACTOR)
-    # else: lr_scheduler = None
-    lr_scheduler = None
-
     prev_cookies = []
     st = time.perf_counter()
 
@@ -484,7 +481,7 @@
 
       tqdm.write(
         f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.2f} classification loss, {losses['regression_loss'].item():5.2f} regression loss, "
+        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
         f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
       )
 
@@ -501,7 +498,7 @@
       if i == BENCHMARK:
         assert not math.isnan(loss)
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
-        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * e / 60)
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * num_epochs / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
         print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
               f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
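
Notes on the two changes:

Benchmark timing: BENCHMARK fires during the first epoch, where the loop variable
e is now start_epoch (0 under the new zero-based range), so the old estimate of
median_step_time * steps_in_train_epoch * e collapsed to zero minutes; multiplying
by num_epochs instead estimates the full training run.

LR scheduler: only the early-exit lines of _lr_lambda fall inside the hunk context
above, so the interpolation in the sketch below is an assumption based on the
linear warmup used by the MLPerf reference implementation; the standalone helper
name and the sample hyperparameter values are illustrative, not code from this
patch.

    # Minimal sketch of a linear-warmup LR lambda, assuming the
    # warmup_factor * (1 - alpha) + alpha form of the MLPerf reference;
    # only the offset and early-exit lines are visible in the hunk above.
    def lr_lambda(step: int, start_iter: int, warmup_iters: int, warmup_factor: float) -> float:
      step += start_iter                       # offset by start_iter, as in the patch
      if step >= warmup_iters: return 1.0      # warmup finished: use the full base LR
      alpha = step / warmup_iters              # fraction of warmup completed
      return warmup_factor * (1 - alpha) + alpha  # ramp from warmup_factor up to 1.0

    if __name__ == "__main__":
      base_lr, warmup_iters, warmup_factor = 1e-4, 800, 1e-3  # hypothetical values
      for step in (0, 400, 800, 1600):
        print(f"step {step}: lr = {base_lr * lr_lambda(step, 0, warmup_iters, warmup_factor):.2e}")

Because start_iter = start_epoch * steps_in_train_epoch feeds into the lambda, the
scheduler can be built once before the epoch loop and still resume mid-warmup,
replacing the per-epoch recreation sketched in the deleted commented-out block.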