don't init global search with points_to_evaluate unless evaluated_rewards is provided; handle callbacks in fit kwargs (#469)

This commit is contained in:
Chi Wang
2022-03-01 18:39:16 -08:00
committed by GitHub
parent df01031cfe
commit 31ac984c4b
4 changed files with 25 additions and 9 deletions

View File

@@ -959,10 +959,16 @@ class LGBMEstimator(BaseEstimator):
# when not trained, train at least one iter
self.params[self.ITER_HP] = max(max_iter, 1)
if self.HAS_CALLBACK:
kwargs_callbacks = kwargs.get("callbacks")
if kwargs_callbacks:
callbacks = kwargs_callbacks + self._callbacks(start_time, deadline)
kwargs.pop("callbacks")
else:
callbacks = self._callbacks(start_time, deadline)
self._fit(
X_train,
y_train,
callbacks=self._callbacks(start_time, deadline),
callbacks=callbacks,
**kwargs,
)
best_iteration = (
@@ -1821,10 +1827,7 @@ class TS_SKLearn(SKLearnEstimator):
"low_cost_init_value": False,
},
"lags": {
"domain": tune.randint(
lower=1, upper=int(np.sqrt(data_size[0]))
),
"domain": tune.randint(lower=1, upper=int(np.sqrt(data_size[0]))),
"init_value": 3,
},
}

View File

@@ -171,6 +171,7 @@ class BlendSearch(Searcher):
else:
sampler = None
try:
assert evaluated_rewards
self._gs = GlobalSearch(
space=gs_space,
metric=metric,
@@ -180,7 +181,7 @@ class BlendSearch(Searcher):
points_to_evaluate=points_to_evaluate,
evaluated_rewards=evaluated_rewards,
)
except ValueError:
except (AssertionError, ValueError):
self._gs = GlobalSearch(
space=gs_space,
metric=metric,

View File

@@ -214,7 +214,12 @@ class TestClassification(unittest.TestCase):
}
X_train = scipy.sparse.eye(900000)
y_train = np.random.randint(2, size=900000)
automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
import xgboost as xgb
callback = xgb.callback.TrainingCallback()
automl_experiment.fit(
X_train=X_train, y_train=y_train, callbacks=[callback], **automl_settings
)
print(automl_experiment.predict(X_train))
print(automl_experiment.model)
print(automl_experiment.config_history)

View File

@@ -197,14 +197,21 @@ def test_searcher():
# sign of metric constraints must be <= or >=.
pass
searcher = BlendSearch(
metric="m", global_search_alg=searcher, metric_constraints=[("c", "<=", 1)]
metric="m",
global_search_alg=searcher,
metric_constraints=[("c", "<=", 1)],
points_to_evaluate=[{"a": 1, "b": 0.01}],
)
searcher.set_search_properties(
metric="m2", config=config, setting={"time_budget_s": 0}
)
c = searcher.suggest("t1")
searcher.on_trial_complete("t1", {"config": c}, True)
print("t1", c)
c = searcher.suggest("t2")
print("t2", c)
c = searcher.suggest("t3")
print("t3", c)
searcher.on_trial_complete("t1", {"config": c}, True)
searcher.on_trial_complete("t2", {"config": c, "m2": 1, "c": 2, "time_total_s": 1})
config1 = config.copy()
config1["_choice_"] = 0