RESET_STEP in bert setup and beam (#9248)

Running dev_beam might OOM without it, but it runs fine in a real run.
chenyu
2025-02-25 19:15:10 -05:00
committed by GitHub
parent 2676c9d46e
commit 979e84f30e
5 changed files with 5 additions and 3 deletions


@@ -818,7 +818,7 @@ def train_bert():
     if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
       if MLLOGGER and RUNMLPERF:
         MLLOGGER.start(key=mllog_constants.EVAL_START, value=None, metadata={"epoch_num": i*BS, "step_num": i})
-      if getenv("RESET_STEP", 0) or INITMLPERF: train_step_bert.reset()
+      if getenv("RESET_STEP", 0): train_step_bert.reset()
       else: train_step_bert.captured.free_intermediates()
       eval_lm_losses = []
       eval_clsf_losses = []
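
For context, a minimal sketch of the toggle this hunk changes, assuming tinygrad's TinyJit and getenv (from tinygrad.helpers); train_step below is an illustrative stand-in for train_step_bert. Before this commit the jitted step was fully reset whenever INITMLPERF was set; now the full reset only happens when RESET_STEP=1 is exported, and the setup/beam scripts below opt in explicitly (per the commit message, dev_beam might OOM without it).

from tinygrad import Tensor, TinyJit
from tinygrad.helpers import getenv

@TinyJit
def train_step(x: Tensor) -> Tensor:
  # stand-in jitted train step (illustrative only)
  return (x * 2).contiguous().realize()

# the JIT captures after a couple of calls, so .captured exists by eval time
for _ in range(3): train_step(Tensor([1.0, 2.0]))

if getenv("RESET_STEP", 0):
  # full reset: drop the captured kernels and buffers; the next call re-captures
  train_step.reset()
else:
  # keep the capture and only free intermediate buffers between eval rounds
  train_step.captured.free_intermediates()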


@@ -8,6 +8,7 @@ export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
+export RESET_STEP=1
 export BENCHMARK=10 DEBUG=2
 python3 examples/mlperf/model_train.py


@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
 LOGFILE="bert_green_${DATETIME}_${SEED}.log"
 # init
-BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
+BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
 # run
 PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE


@@ -8,6 +8,7 @@ export BEAM=3 BEAM_UOPS_MAX=3000 BEAM_UPCAST_MAX=256 BEAM_LOCAL_MAX=1024
 export IGNORE_JIT_FIRST_BEAM=1
 export BASEDIR="/raid/datasets/wiki"
+export RESET_STEP=1
 export BENCHMARK=10 DEBUG=2
 python3 examples/mlperf/model_train.py


@@ -17,7 +17,7 @@ DATETIME=$(date "+%m%d%H%M")
 LOGFILE="bert_red_${DATETIME}_${SEED}.log"
 # init
-BENCHMARK=10 INITMLPERF=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
+BENCHMARK=10 INITMLPERF=1 RESET_STEP=1 python3 examples/mlperf/model_train.py | tee $LOGFILE
 # run
 PARALLEL=0 RUNMLPERF=1 python3 examples/mlperf/model_train.py | tee -a $LOGFILE