backup & recover global vars for nested tune.run (#584)

* backup & recover global vars for nested tune.run

* ensure recovering global vars before return
This commit is contained in:
Chi Wang
2022-06-14 11:03:54 -07:00
committed by GitHub
parent 65fa72d583
commit 1111d6d43a
2 changed files with 155 additions and 85 deletions

View File

@@ -20,6 +20,37 @@ logger.addHandler(logging.FileHandler("logs/tune.log"))
logger.setLevel(logging.INFO)
def test_nested_run():
from flaml import AutoML, tune
data, labels = sklearn.datasets.load_breast_cancer(return_X_y=True)
train_x, val_x, y_train, y_val = train_test_split(data, labels, test_size=0.25)
space_pca = {
"n_components": tune.uniform(0.5, 0.99),
}
def pca_flaml(config):
n_components = config["n_components"]
from sklearn.decomposition import PCA
pca = PCA(n_components)
X_train = pca.fit_transform(train_x)
X_val = pca.transform(val_x)
automl = AutoML()
automl.fit(X_train, y_train, X_val=X_val, y_val=y_val, time_budget=1)
return {"loss": automl.best_loss}
analysis = tune.run(
pca_flaml,
space_pca,
metric="loss",
mode="min",
num_samples=5,
local_dir="logs",
)
print(analysis.best_result)
def train_breast_cancer(config: dict):
# This is a simple training function to be passed into Tune
# Load dataset
@@ -182,7 +213,7 @@ def _test_xgboost(method="BlendSearch"):
logger.info(f"Best model parameters: {best_trial.config}")
def test_nested():
def test_nested_space():
from flaml import tune, CFO
search_space = {