diff --git a/examples/mlperf/dataloader.py b/examples/mlperf/dataloader.py
index e1aace4fd2..aaef33fddb 100644
--- a/examples/mlperf/dataloader.py
+++ b/examples/mlperf/dataloader.py
@@ -207,14 +207,16 @@ def load_datasample(file_and_offset:Tuple[str, int]) -> List[dict]:
   return data[file_and_offset[1]]
 
 # Reference: https://github.com/mlcommons/training/blob/1c8a098ae3e70962a4f7422c0b0bd35ae639e357/language_model/tensorflow/bert/run_pretraining.py, Line 394
-def batch_load_train_bert(BS:int):
+def batch_load_train_bert(BS:int, start_step:int = 0):
   from extra.datasets.wikipedia import get_wiki_train_files
   files = shuffle_parts(get_wiki_train_files())
   dataset = []
-  for f in files:
+  for f in tqdm(files, desc="Building dataset"):
     lists = [(f, o) for o in range(int(Path(f).stem.split("_")[3].split(".")[0]))]
     dataset.extend(lists)
+  dataset = dataset[start_step * BS:]
+
   active_set = deque(dataset[:1000])
   remaining_set = deque(dataset[1000:])
 
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index af5011cd37..c3c46a1edc 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -414,6 +414,7 @@ def train_bert():
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
   save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000)
   keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5)
+  save_ckpt_dir = config["SAVE_CKPT_DIR"] = getenv("SAVE_CKPT_DIR", "./ckpts")
   init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)
 
   loss_scaler = config["loss_scaler"] = getenv("LOSS_SCALER", 2**9 if dtypes.default_float == dtypes.float16 else 1.0)
@@ -437,25 +438,27 @@ def train_bert():
   assert 10000 <= (EVAL_BS * max_eval_steps), "Evaluation batchsize * max_eval_steps must greater or equal 10000 to iterate over full eval dataset"
 
-  # ** Log hparams **
+  # ** Log run config **
   for key, value in config.items():
     print(f'HParam: "{key}": {value}')
 
   # ** Optimizer **
-  skip_list = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
-  parameters = [x for x in parameters if x not in set(skip_list)]
-  optimizer = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
-  optimizer_skip = LAMB(skip_list, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
-  optimizer_group = OptimizerGroup(optimizer, optimizer_skip)
+  parameters_no_wd = [v for k, v in get_state_dict(model).items() if "bias" in k or "LayerNorm" in k]
+  parameters = [x for x in parameters if x not in set(parameters_no_wd)]
+  optimizer_wd = LAMB(parameters, lr=max_lr, eps=1e-6, wd=decay, adam=False)
+  optimizer_no_wd = LAMB(parameters_no_wd, lr=max_lr, eps=1e-6, wd=0.0, adam=False)
+  optimizer_group = OptimizerGroup(optimizer_wd, optimizer_no_wd)
 
   # ** LR scheduler **
-  scheduler = PolynomialDecayWithWarmup(optimizer, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_wd = PolynomialDecayWithWarmup(optimizer_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_no_wd = PolynomialDecayWithWarmup(optimizer_no_wd, max_lr, 0, train_steps, warmup_steps, power=poly_power)
+  scheduler_group = LRSchedulerGroup(scheduler_wd, scheduler_no_wd)
   print(f"training with batch size {BS} for one epoch with {train_steps} steps")
 
   # ** resume from checkpointing **
   start_step = 0
   if ckpt:=getenv("RESUME", ""):
-    load_training_state(model, optimizer_group, scheduler, safe_load(ckpt))
-    start_step = scheduler.epoch_counter.numpy().item()
+    load_training_state(model, optimizer_group, scheduler_group, safe_load(ckpt))
+    start_step = int(scheduler_wd.epoch_counter.numpy().item())
     print(f"resuming from {ckpt} at step {start_step}")
 
   # ** init wandb **
@@ -468,18 +471,18 @@ def train_bert():
   BENCHMARK = getenv("BENCHMARK")
 
   eval_it = iter(batch_load_val_bert(EVAL_BS))
-  train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
+  train_it = iter(tqdm(batch_load_train_bert(BS, start_step), total=train_steps, disable=BENCHMARK))
 
   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
-  i, train_data = 0, get_data_bert(GPUS, train_it)
+  i, train_data = start_step, get_data_bert(GPUS, train_it)
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
     BEAM.value = TRAIN_BEAM
     st = time.perf_counter()
     GlobalCounters.reset()
-    loss = train_step_bert(model, optimizer_group, scheduler, loss_scaler,
+    loss = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
       train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
       train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"])
 
@@ -500,10 +503,10 @@ def train_bert():
 
     tqdm.write(
       f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer.lr.numpy()[0]:.6f} LR, "
+      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
       f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
     if WANDB:
-      wandb.log({"lr": optimizer.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
+      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/step_time": cl - st,
         "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
         "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st)})
 
@@ -569,7 +572,7 @@ def train_bert():
       # save model if achieved target
       if not achieved and avg_lm_acc >= target:
         wc_end = time.perf_counter()
-        if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+        if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
         fn = f"{ckpt_dir}/bert-large.safe"
         safe_save(get_state_dict(model), fn)
         print(f" *** Model saved to {fn} ***")
@@ -584,13 +587,13 @@ def train_bert():
           break
 
     if getenv("CKPT") and i % save_ckpt_freq == 0:
-      if not os.path.exists(ckpt_dir := getenv('CKPT_DIR', "./ckpts")): os.mkdir(ckpt_dir)
+      if not os.path.exists(ckpt_dir := save_ckpt_dir): os.mkdir(ckpt_dir)
       if WANDB and wandb.run is not None:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}_{wandb.run.id}.safe"
       else:
         fn = f"{ckpt_dir}/{time.strftime('%Y%m%d_%H%M%S')}.safe"
      print(f"saving ckpt to {fn}")
-      safe_save(get_training_state(model, optimizer_group, scheduler), fn)
+      safe_save(get_training_state(model, optimizer_group, scheduler_group), fn)
       ckpt_files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
       ckpt_files.sort(key=lambda x: os.path.getmtime(os.path.join(ckpt_dir, x)))
       while len(ckpt_files) > keep_ckpt_amount:
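
Note on the resume path: each training step consumes BS samples, so a run resumed at `start_step` must skip `start_step * BS` entries of the shuffled sample list for step i of the resumed run to see the same batch as step i of an uninterrupted run. A minimal standalone sketch of that arithmetic (the numbers and the plain-list stand-in for the `(file, offset)` dataset are made up):

  # Toy check that the resume offset reproduces the original batch boundaries.
  BS, start_step = 4, 3
  dataset = list(range(20))            # 20 samples -> 5 batches of 4
  resumed = dataset[start_step * BS:]  # skip the 3 batches already consumed
  assert resumed[:BS] == dataset[start_step * BS:(start_step + 1) * BS]
  print(resumed[:BS])                  # [12, 13, 14, 15], i.e. batch 3 of a fresh run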
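
The final hunk is cut off at the `while len(ckpt_files) > keep_ckpt_amount:` condition, so the loop body is not shown here. As a self-contained sketch of the rotation policy that condition implies, keeping only the newest checkpoints by mtime (the `os.remove` and `pop(0)` are assumptions, not taken from the diff):

  import os

  # Drop the oldest checkpoints (by mtime) until at most `keep` remain.
  # The deletion call is an assumption; the diff truncates before the loop body.
  def rotate_ckpts(ckpt_dir: str, keep: int) -> None:
    files = [f for f in os.listdir(ckpt_dir) if os.path.isfile(os.path.join(ckpt_dir, f))]
    files.sort(key=lambda f: os.path.getmtime(os.path.join(ckpt_dir, f)))  # oldest first
    while len(files) > keep:
      os.remove(os.path.join(ckpt_dir, files.pop(0)))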