check for CKPT when target metric is reached before saving

This commit is contained in:
Francis Lata
2025-03-02 00:41:08 -08:00
parent 3ac4ae5870
commit 27ec792c19

View File

@@ -595,7 +595,7 @@ def train_retinanet():
if val_metric >= target_metric:
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}"))
safe_save(get_state_dict(model), fn)
if getenv("CKPT", 1): safe_save(get_state_dict(model), fn)
break
def train_unet3d():