diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index e243bf9c0f..f864760100 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -1142,12 +1142,7 @@ def train_bert(): train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"]) pt = time.perf_counter() - - try: - next_data = next(train_it) - except StopIteration: - next_data = None - + next_data = next(train_it) dt = time.perf_counter() device_str = parameters[0].device if isinstance(parameters[0].device, str) else f"{parameters[0].device[0]} * {len(parameters[0].device)}"