From d8d7ac1bb1575d8071f6b13699d20e133fc2507a Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 30 Mar 2025 22:42:52 -0400 Subject: [PATCH] fix bert free_intermediates (#9633) fix when only run eval `TRAIN=0 BERT_SIZE=tiny examples/mlperf/training_submission_v5.0/tinycorp/benchmarks/bert/implementations/tinybox_green/dev_beam.sh` --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index aaf4e8f219..56a3f728da 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -826,7 +826,7 @@ def train_bert(): if MLLOGGER and RUNMLPERF: MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i}) if getenv("RESET_STEP", 0): train_step_bert.reset() - else: train_step_bert.captured.free_intermediates() + elif train_step_bert.captured is not None: train_step_bert.captured.free_intermediates() eval_lm_losses = [] eval_clsf_losses = [] eval_lm_accs = [] @@ -860,7 +860,7 @@ def train_bert(): return if getenv("RESET_STEP", 0): eval_step_bert.reset() - else: eval_step_bert.captured.free_intermediates() + elif eval_step_bert.captured is not None: eval_step_bert.captured.free_intermediates() del eval_data avg_lm_loss = sum(eval_lm_losses) / len(eval_lm_losses) avg_clsf_loss = sum(eval_clsf_losses) / len(eval_clsf_losses)