From 27ec792c19cbdbee82861abdd2ffc9bf7075728c Mon Sep 17 00:00:00 2001 From: Francis Lata Date: Sun, 2 Mar 2025 00:41:08 -0800 Subject: [PATCH] check for CKPT when target metric is reached before saving --- examples/mlperf/model_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 1e2d7bbb17..8f3e3d07c3 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -595,7 +595,7 @@ def train_retinanet(): if val_metric >= target_metric: print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}")) - safe_save(get_state_dict(model), fn) + if getenv("CKPT", 1): safe_save(get_state_dict(model), fn) break def train_unet3d():