diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py
index c581eb0324..ab90580fad 100644
--- a/examples/mlperf/model_train.py
+++ b/examples/mlperf/model_train.py
@@ -181,13 +181,15 @@ def train_resnet():
         median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2]  # in seconds
         estimated_total_minutes = int(median_step_time * steps_in_train_epoch * epochs / 60)
         print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+        print(f"epoch global_ops: {steps_in_train_epoch * GlobalCounters.global_ops:_}, "
+              f"epoch global_mem: {steps_in_train_epoch * GlobalCounters.global_mem:_}")
         # if we are doing beam search, run the first eval too
         if (TRAIN_BEAM or EVAL_BEAM) and e == start_epoch: break
         return
 
     # ** eval loop **
     if (e + 1 - eval_start_epoch) % eval_epochs == 0 and steps_in_val_epoch > 0:
-      train_step.reset()  # free the train step memory :(
+      if getenv("RESET_STEP", 1): train_step.reset()  # free the train step memory :(
       eval_loss = []
       eval_times = []
       eval_top_1_acc = []
@@ -214,7 +216,7 @@ def train_resnet():
         et = time.time()
         eval_times.append(et - st)
 
-      eval_step.reset()
+      if getenv("RESET_STEP", 1): eval_step.reset()
       total_loss = sum(eval_loss) / len(eval_loss)
       total_top_1 = sum(eval_top_1_acc) / len(eval_top_1_acc)
       total_fw_time = sum(eval_times) / len(eval_times)
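
Not part of the patch: a minimal standalone sketch of what the two changes do, assuming tinygrad's TinyJit, getenv, and GlobalCounters (import paths may differ between tinygrad versions). It shows the ops/memory counters the new print reads, and the RESET_STEP toggle that decides whether a captured jit step is freed.

# Minimal sketch, not the patched training script; names below are illustrative.
from tinygrad import Tensor, TinyJit
from tinygrad.helpers import getenv, GlobalCounters

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2).sum().realize()

for _ in range(3): step(Tensor.rand(16, 16))  # warm up and capture the jit

# the same counters the new print uses; ":_" groups digits with underscores
print(f"global_ops: {GlobalCounters.global_ops:_}, global_mem: {GlobalCounters.global_mem:_}")

# RESET_STEP=1 (the default) frees the captured step, matching the old behavior;
# RESET_STEP=0 keeps it cached so the next epoch skips re-jitting at the cost of memory.
if getenv("RESET_STEP", 1): step.reset()

Gating the reset behind getenv keeps the previous default, so existing runs are unchanged while benchmark or memory-rich runs can opt out of re-jitting between train and eval.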