diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 2b3214ae9d..945c00abc3 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -410,12 +410,12 @@ def train_retinanet(): # ** hyperparameters ** # using https://github.com/mlcommons/logging/blob/96d0acee011ba97702532dcc39e6eeaa99ebef24/mlperf_logging/rcp_checker/training_4.1.0/rcps_ssd.json#L3 - config["lr"] = lr = 1e-4 - config["lr_warmup_epochs"] = lr_warmup_epochs = 1 - config["lr_warmup_factor"] = lr_warmup_factor = 1e-3 config["seed"] = SEED = getenv("SEED", random.SystemRandom().randint(0, 2**32 - 1)) config["bs"] = BS = getenv("BS", 128) config["epochs"] = EPOCHS = getenv("EPOCHS", 4) + config["lr"] = lr = 1e-4 * (BS / 256) + config["lr_warmup_epochs"] = lr_warmup_epochs = 1 + config["lr_warmup_factor"] = lr_warmup_factor = 1e-3 if SEED: Tensor.manual_seed(SEED)