mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
check for CKPT when target metric is reached before saving
This commit is contained in:
@@ -595,7 +595,7 @@ def train_retinanet():
|
||||
|
||||
if val_metric >= target_metric:
|
||||
print(colored(f"target metric reached: {val_metric:.2f}/{target_metric:.2f}"))
|
||||
safe_save(get_state_dict(model), fn)
|
||||
if getenv("CKPT", 1): safe_save(get_state_dict(model), fn)
|
||||
break
|
||||
|
||||
def train_unet3d():
|
||||
|
||||
Reference in New Issue
Block a user