diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index dc2d2b3356..a89ef4ea32 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -410,7 +410,7 @@ def train_bert(): train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 4800000 // BS) warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", train_steps // 10) max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000 - eval_step_freq = config["EVAL_STEP_FREQ"] = int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS) # Round down + eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down save_ckpt_freq = config["SAVE_CKPT_FREQ"] = getenv("SAVE_CKPT_FREQ", 1000) keep_ckpt_amount = config["KEEP_CKPT_AMOUNT"] = getenv("KEEP_CKPT_AMOUNT", 5) init_ckpt = config["INIT_CKPT_DIR"] = getenv("INIT_CKPT_DIR", BASEDIR)