diff --git a/examples/mlperf/model_train.py b/examples/mlperf/model_train.py index 1e2d7bbb17..8f3e3d07c3 100644 --- a/examples/mlperf/model_train.py +++ b/examples/mlperf/model_train.py @@ -595,7 +595,7 @@ def train_retinanet(): if val_metric >= target_metric: print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}")) - safe_save(get_state_dict(model), fn) + if getenv("CKPT", 1): safe_save(get_state_dict(model), fn) break def train_unet3d():