enable lr scheduler and fix benchmark timing

Francis Lata
2025-01-22 09:56:38 -08:00
parent 66ff6cb37a
commit efe64ebeaf


@@ -377,7 +377,6 @@ def train_retinanet():
     return x.shard(GPUS, axis=0).realize(), y_bboxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
 
   def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
-    # TODO: refactor this a bit more so we don't have to recreate it, unlike what the MLPerf script is doing
     def _lr_lambda(e):
       e = e + start_iter
       if e >= warmup_iters: return 1.0
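The warmup branch of _lr_lambda is cut off by the hunk above; only the early return at the end of warmup is visible. As a rough standalone sketch (the linear interpolation below is an assumption based on the common MLPerf/torchvision warmup form, not taken from this file), the lambda ramps the LR multiplier from warmup_factor up to 1.0 over warmup_iters steps:

# Hypothetical sketch of a linear-warmup lambda like _lr_lambda.
# Only the `e >= warmup_iters` early return is visible in the diff;
# the ramp below is assumed.
def lr_lambda(e, start_iter=0, warmup_iters=1000, warmup_factor=1e-3):
  e = e + start_iter                          # offset so resumed runs continue the ramp
  if e >= warmup_iters: return 1.0            # warmup finished: full base LR
  alpha = e / warmup_iters                    # fraction of warmup completed
  return warmup_factor * (1 - alpha) + alpha  # linear ramp up to 1.0

# effective LR at training step i would be base_lr * lr_lambda(i)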
@@ -394,7 +393,7 @@ def train_retinanet():
     loss.backward()
     optim.step()
-    # lr_scheduler.step()
+    lr_scheduler.step()
     return loss.realize(), losses
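Uncommenting lr_scheduler.step() here means the schedule advances once per optimizer step, not once per epoch, which is what a per-iteration warmup needs. A toy, self-contained illustration of that ordering (ToyOptim and ToyScheduler are hypothetical stand-ins, not tinygrad classes):

# Toy sketch of per-iteration scheduler stepping; names are made up.
class ToyOptim:
  def __init__(self, lr): self.lr = lr

class ToyScheduler:
  def __init__(self, optim, base_lr, fn): self.optim, self.base_lr, self.fn, self.t = optim, base_lr, fn, 0
  def step(self):
    self.t += 1
    self.optim.lr = self.base_lr * self.fn(self.t)  # rescale LR every call

optim = ToyOptim(lr=1e-4)
sched = ToyScheduler(optim, 1e-4, lambda t: min(1.0, t / 100))  # 100-step warmup
for i in range(3):
  # loss.backward(); optim.step() would go here in the real loop
  sched.step()  # advance the warmup once per iteration, as the hunk enables
  print(i, optim.lr)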
@@ -435,6 +434,8 @@ def train_retinanet():
     for p in params: p.to_(GPUS)
 
+  step_times, start_epoch = [], 0
+
   # ** optimizer **
   optim = Adam(params, lr=lr)
@@ -442,21 +443,17 @@ def train_retinanet():
   train_dataset = COCO(download_dataset(BASE_DIR, "train"))
   val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
 
   # ** lr scheduler **
   config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), bs) // bs
-  step_times = []
+  start_iter, warmup_iters = start_epoch * steps_in_train_epoch, lr_warmup_epochs * steps_in_train_epoch
+  lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)
 
   # ** training loop **
-  for e in range(1, num_epochs + 1):
+  for e in range(start_epoch, num_epochs):
     train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=bs, seed=seed)
     it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, _data_get(it)
-    # if e < LR_WARMUP_EPOCHS:
-    #   start_iter, warmup_iters = e * train_dataset_len, LR_WARMUP_EPOCHS * train_dataset_len
-    #   lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, LR_WARMUP_FACTOR)
-    # else: lr_scheduler = None
-    lr_scheduler = None
 
     prev_cookies = []
     st = time.perf_counter()
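The point of creating the scheduler once, before the epoch loop, is that start_iter fast-forwards the warmup when resuming from start_epoch > 0, instead of the old dead code that rebuilt (or nulled out) the scheduler inside the loop. A small sketch of the arithmetic, with round_up redefined locally (assumed to match tinygrad's helper) and illustrative COCO-ish numbers:

# Sketch of the steps-per-epoch and resume-offset arithmetic above.
def round_up(n, d): return ((n + d - 1) // d) * d  # assumed equivalent to tinygrad's round_up

num_imgs, bs = 118_287, 96                              # illustrative dataset size / batch size
steps_in_train_epoch = round_up(num_imgs, bs) // bs     # ceil division -> 1233
start_epoch, lr_warmup_epochs = 2, 1                    # pretend we resume at epoch 2
start_iter = start_epoch * steps_in_train_epoch         # 2466: past warmup, lambda returns 1.0
warmup_iters = lr_warmup_epochs * steps_in_train_epoch  # 1233
print(steps_in_train_epoch, start_iter, warmup_iters)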
@@ -484,7 +481,7 @@ def train_retinanet():
         tqdm.write(
           f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-          f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.2f} classification loss, {losses['regression_loss'].item():5.2f} regression loss, "
+          f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {losses['classification_loss'].item():5.4f} classification loss, {losses['regression_loss'].item():5.4f} regression loss, "
           f"{optim.lr.numpy()[0]:.6f} LR, {GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS"
         )
@@ -501,7 +498,7 @@ def train_retinanet():
       if i == BENCHMARK:
         assert not math.isnan(loss)
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
-        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * e / 60)
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * num_epochs / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
         print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
               f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")