Fix checkpoint path issue

checkpoint path may be named dir_or_data instead of value
This commit is contained in:
Jing Dong
2022-12-16 14:41:33 +08:00
committed by GitHub
parent c1872861b6
commit 5778227a71

View File

@@ -261,7 +261,8 @@ if torch.cuda.is_available():
best_trained_model = nn.DataParallel(best_trained_model)
best_trained_model.to(device)
checkpoint_path = os.path.join(best_trial.checkpoint.value, "checkpoint")
checkpoint_value = getattr(best_trial.checkpoint, "dir_or_data", None) or best_trial.checkpoint.value
checkpoint_path = os.path.join(checkpoint_value, "checkpoint")
model_state, optimizer_state = torch.load(checkpoint_path)
best_trained_model.load_state_dict(model_state)
@@ -283,4 +284,4 @@ Files already downloaded and verified
Best trial test set accuracy: 0.6294
```
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)
[Link to notebook](https://github.com/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb) | [Open in colab](https://colab.research.google.com/github/microsoft/FLAML/blob/main/notebook/tune_pytorch.ipynb)