From 43d3a75d6c36856dd741c49c95dfad1787271390 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 14 Apr 2025 08:53:44 -0400 Subject: [PATCH] increase bert max train_steps (#9883) --- examples/mlperf/model_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index f3ab8f6405..6f265eb9d9 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -897,7 +897,7 @@ def train_bert(): opt_lamb_beta_1 = config["OPT_LAMB_BETA_1"] = getenv("OPT_LAMB_BETA_1", 0.9) opt_lamb_beta_2 = config["OPT_LAMB_BETA_2"] = getenv("OPT_LAMB_BETA_2", 0.999) - train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS) + train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3600000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1) max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000 eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down