diff --git a/examples/mlperf/helpers.py b/examples/mlperf/helpers.py
index c601bf490d..df9c98fb3b 100644
--- a/examples/mlperf/helpers.py
+++ b/examples/mlperf/helpers.py
@@ -227,13 +227,13 @@ def get_data_bert(GPUS:list[str], it):
   for key in data.keys(): data[key].shard_(GPUS, axis=0)
   return data
 
-def get_fake_data_bert(GPUS:list[str], BS:int):
+def get_fake_data_bert(BS:int):
   return {
-    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float).contiguous().shard_(GPUS, axis=0),
-    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
-    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32).contiguous().shard_(GPUS, axis=0),
+    "input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32),
+    "input_mask": Tensor.empty((BS, 512), dtype=dtypes.default_float),
+    "segment_ids": Tensor.empty((BS, 512), dtype=dtypes.float32),
+    "masked_lm_positions": Tensor.empty((BS, 76), dtype=dtypes.float32),
+    "masked_lm_ids": Tensor.empty((BS, 76), dtype=dtypes.float32),
+    "masked_lm_weights": Tensor.empty((BS, 76), dtype=dtypes.float32),
+    "next_sentence_labels": Tensor.empty((BS, 1), dtype=dtypes.float32),
   }
diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index 932b90bb1f..2c6c621aef 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -747,18 +747,22 @@ def train_bert():
     eval_it = iter(batch_load_val_bert(EVAL_BS))
     train_it = iter(tqdm(batch_load_train_bert(BS), total=train_steps, disable=BENCHMARK))
     for _ in range(start_step): next(train_it) # Fast forward
-
+  else:
+    # repeat fake data
+    def repeat_fake(bs):
+      while True: yield get_fake_data_bert(bs)
+    eval_it = iter(repeat_fake(EVAL_BS))
+    train_it = iter(repeat_fake(BS))
 
   step_times = []
   # ** train loop **
   wc_start = time.perf_counter()
+
+  i, train_data = start_step, get_data_bert(GPUS, train_it)
+
   if RUNMLPERF:
-    # only load real data with RUNMLPERF
-    i, train_data = start_step, get_data_bert(GPUS, train_it)
     if MLLOGGER:
       MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
-  else:
-    i, train_data = start_step, get_fake_data_bert(GPUS, BS)
 
   while train_data is not None and i < train_steps and not achieved:
     Tensor.training = True
@@ -772,10 +776,7 @@ def train_bert():
     pt = time.perf_counter()
 
     try:
-      if RUNMLPERF:
-        next_data = get_data_bert(GPUS, train_it)
-      else:
-        next_data = get_fake_data_bert(GPUS, BS)
+      next_data = get_data_bert(GPUS, train_it)
     except StopIteration:
       next_data = None
 
@@ -821,10 +822,7 @@ def train_bert():
       BEAM.value = EVAL_BEAM
 
      for j in tqdm(range(max_eval_steps), desc="Evaluating", total=max_eval_steps, disable=BENCHMARK):
-        if RUNMLPERF:
-          eval_data = get_data_bert(GPUS, eval_it)
-        else:
-          eval_data = get_fake_data_bert(GPUS, EVAL_BS)
+        eval_data = get_data_bert(GPUS, eval_it)
         GlobalCounters.reset()
         st = time.time()
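
Note on the refactor: sharding now happens in exactly one place. get_fake_data_bert returns plain unsharded tensors, and both the real and fake paths hand an iterator to get_data_bert, which is the single point that calls shard_. Below is a minimal sketch of the resulting data path, not verbatim source: the body of get_data_bert is only partly visible in the first hunk, so the next(it) line is an assumption, and the device names in the usage line are hypothetical.

from tinygrad import Tensor, dtypes

def get_fake_data_bert(BS:int):
  # plain host tensors; no .contiguous().shard_() here anymore (other keys elided)
  return {"input_ids": Tensor.empty((BS, 512), dtype=dtypes.float32)}

def get_data_bert(GPUS:list[str], it):
  data = next(it)  # assumption: the diff only shows the tail of this function
  for key in data.keys(): data[key].shard_(GPUS, axis=0)  # the one sharding point
  return data

def repeat_fake(bs):
  # infinite generator, mirroring the else-branch added in train_bert()
  while True: yield get_fake_data_bert(bs)

train_data = get_data_bert(["CUDA:0", "CUDA:1"], iter(repeat_fake(8)))  # hypothetical devices

Because repeat_fake never raises StopIteration, the next_data = None fallback in the train loop is only reachable when real data runs out; fake-data runs terminate through the i < train_steps bound instead.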