From d2cfbd8a4d2ebb82d6caf3ee8f03d12d39107445 Mon Sep 17 00:00:00 2001 From: chenyu Date: Sun, 16 Mar 2025 17:21:20 -0400 Subject: [PATCH] bert lower learning rate and total steps (#9466) closer to the other submission with BS=240. converged with 10% fewer epochs --- examples/mlperf/model_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 83e1a2a107..d0f18bf487 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -658,9 +658,9 @@ def train_bert(): # ** hyperparameters ** BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS)) EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS)) - max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96)) + max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.00018 * math.sqrt(BS/96)) - train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS) + train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1) max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000 eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down