diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 76032db841..dba4b40b54 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -363,6 +363,7 @@ def train_retinanet():
   config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
   for x in GPUS: Device[x]
+  print(f"training on {GPUS}")
 
   def _freeze_backbone_layers(backbone, trainable_layers, loaded_keys):
     model_layers = ["layer4", "layer3", "layer2", "layer1", "conv1"][:trainable_layers]
@@ -403,24 +404,13 @@ def train_retinanet():
   config["lr"] = lr = 1e-4
   config["lr_warmup_epochs"] = lr_warmup_epochs = 1
   config["lr_warmup_factor"] = lr_warmup_factor = 1e-3
-  config["seed"] = seed = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
-  config["bs"] = bs = getenv("BS", 128)
-  config["num_epochs"] = num_epochs = getenv("EPOCHS", 4)
+  config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1))
+  config["bs"] = BS = getenv("BS", 128)
+  config["epochs"] = EPOCHS = getenv("EPOCHS", 4)
 
-  if seed:
-    Tensor.manual_seed(seed)
-    np.random.seed(seed=seed)
-
-  # ** initialize wandb **
-  if (WANDB := getenv("WANDB")):
-    import wandb
-
-    wandb_args = {"project": "MLPerf-RetinaNet"}
-    if (wandb_id := getenv("WANDB_RESUME", "")):
-      wandb_args["id"] = wandb_id
-      wandb_args["resume"] = "must"
-
-    wandb.init(config=config, **wandb_args)
+  if SEED:
+    Tensor.manual_seed(SEED)
+    np.random.seed(seed=SEED)
 
   # ** model initializers **
   resnet.BatchNorm = FrozenBatchNorm2d
@@ -445,13 +435,32 @@ def train_retinanet():
   val_dataset = COCO(download_dataset(BASE_DIR, "validation"))
 
   # ** lr scheduler **
-  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), bs) // bs
+  config["steps_in_train_epoch"] = steps_in_train_epoch = round_up(len(train_dataset.imgs.keys()), BS) // BS
   start_iter, warmup_iters = start_epoch * steps_in_train_epoch, lr_warmup_epochs * steps_in_train_epoch
   lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)
 
+  # ** resume from checkpointing **
+  if ckpt := getenv("RESUME", ""):
+    load_training_state(model, optim, lr_scheduler, safe_load(ckpt))
+    start_epoch = int(lr_scheduler.epoch_counter.item() / steps_in_train_epoch)
+    print(f"resuming from {ckpt} at epoch {start_epoch}")
+
+  # ** initialize wandb **
+  if WANDB := getenv("WANDB"):
+    import wandb
+
+    wandb_args = {"project": "MLPerf-RetinaNet"}
+    if wandb_id := getenv("WANDB_RESUME", ""):
+      wandb_args["id"] = wandb_id
+      wandb_args["resume"] = "must"
+
+    wandb.init(config=config, **wandb_args)
+
+  print(f"training with batch size {BS} for {EPOCHS} epochs")
+
   # ** training loop **
-  for e in range(start_epoch, num_epochs):
-    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=bs, seed=seed)
+  for e in range(start_epoch, EPOCHS):
+    train_dataloader = batch_load_retinanet(train_dataset, False, Path(BASE_DIR), batch_size=BS, seed=SEED)
     it = iter(tqdm(train_dataloader, total=steps_in_train_epoch, desc=f"epoch {e}", disable=BENCHMARK))
     i, proc = 0, _data_get(it)
 
@@ -499,12 +508,21 @@ def train_retinanet():
       if i == BENCHMARK:
         assert not math.isnan(loss)
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
-        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * num_epochs / 60)
+        estimated_total_minutes = int(median_step_time * steps_in_train_epoch * EPOCHS / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
         print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
               f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
         return
 
+    if getenv("CKPT"):
+      if not os.path.exists(ckpt_dir := Path(getenv("CKPT_DIR", "./ckpts"))): os.mkdir(ckpt_dir)
+      if WANDB and wandb.run is not None:
+        fn = ckpt_dir / Path(f"{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}_e{e}.safe")
+      else:
+        fn = ckpt_dir / Path(f"{time.strftime('%Y%m%d_%H%M%S')}_e{e}.safe")
+      print(f"saving ckpt to {fn}")
+      safe_save(get_training_state(model, optim, lr_scheduler), fn)
+
 def train_unet3d():
   """
   Trains the UNet3D model.
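Taken together, the patch makes a RetinaNet run controllable from the environment: SEED fixes the RNGs, CKPT=1 writes a timestamped .safe file per epoch into CKPT_DIR (default ./ckpts), and RESUME=<path> restores model, optimizer, and LR-scheduler state, recovering the epoch to restart from out of the scheduler's per-step epoch_counter. The sketch below illustrates the same save/resume round trip in isolation. It is a simplified stand-in, not the patch's code path: TinyNet is a hypothetical toy model, and it persists model weights only via get_state_dict/load_state_dict, whereas the patch's get_training_state/load_training_state helpers also bundle optimizer and scheduler tensors.

# Minimal save/resume round trip mirroring the CKPT/RESUME flow above.
# Assumptions (not from the patch): TinyNet stands in for RetinaNet, and
# only model weights are captured here; the real flow also saves optimizer
# and lr_scheduler state so a resumed run continues mid-schedule.
import time
from pathlib import Path
from tinygrad import Tensor, nn
from tinygrad.nn.state import get_state_dict, load_state_dict, safe_save, safe_load

class TinyNet:
  def __init__(self): self.l1 = nn.Linear(16, 4)
  def __call__(self, x: Tensor) -> Tensor: return self.l1(x)

model = TinyNet()

# save: one timestamped .safe file per epoch, as the CKPT branch does
ckpt_dir = Path("./ckpts"); ckpt_dir.mkdir(exist_ok=True)
fn = ckpt_dir / f"{time.strftime('%Y%m%d_%H%M%S')}_e0.safe"
safe_save(get_state_dict(model), str(fn))

# resume: the patch reads the checkpoint path from RESUME and safe_loads it
resumed = TinyNet()
load_state_dict(resumed, safe_load(str(fn)))

Two details of the patch worth noting: start_epoch is not stored in the checkpoint but recomputed as lr_scheduler.epoch_counter.item() / steps_in_train_epoch, which works because the scheduler steps once per iteration; and wandb initialization now happens after the RESUME handling, so a resumed run can pair WANDB_RESUME=<run id> with the restored epoch.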