From f59517754e6fcfb6e0b294b5610b2e5a8af86590 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 30 Sep 2024 09:39:04 -0400 Subject: [PATCH] add RESET_STEP in bert to control reset (#6818) same as resnet --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index b3aac3e53a..4c661ebd5a 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -802,7 +802,7 @@ def train_bert(): if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK): if MLLOGGER and RUNMLPERF: MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": 1, "epoch_count": 1, "step_num": i}) - train_step_bert.reset() + if getenv("RESET_STEP", 1): train_step_bert.reset() eval_lm_losses = [] eval_clsf_losses = [] eval_lm_accs = [] @@ -840,7 +840,7 @@ def train_bert(): MLLOGGER.event(key=mllog_constants.INIT_STOP, value=None) return - eval_step_bert.reset() + if getenv("RESET_STEP", 1): eval_step_bert.reset() del eval_data, eval_result avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses) avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)