check config constraints for the initial config (#685)

* check config constraints for the initial config

* default config value
This commit is contained in:
Chi Wang
2022-08-15 05:30:23 -07:00
committed by GitHub
parent 2e8e3937ef
commit 5e1059ab82
3 changed files with 40 additions and 26 deletions

View File

@@ -752,6 +752,9 @@ class BlendSearch(Searcher):
if result: # tried before
return None
elif result is None: # not tried before
if self._violate_config_constriants(config, config_signature):
# violate config constraints
return None
self._result[config_signature] = {}
else: # running but no result yet
return None
@@ -772,6 +775,32 @@ class BlendSearch(Searcher):
config[INCUMBENT_RESULT] = choice_thread.best_result
return config
def _violate_config_constriants(self, config, config_signature):
"""check if config violates config constraints.
If so, set the result to worst and return True.
"""
if not self._config_constraints:
return False
for constraint in self._config_constraints:
func, sign, threshold = constraint
value = func(config)
if (
sign == "<="
and value > threshold
or sign == ">="
and value < threshold
or sign == ">"
and value <= threshold
or sign == "<"
and value > threshold
):
self._result[config_signature] = {
self._metric: np.inf * self._ls.metric_op,
"time_total_s": 1,
}
return True
return False
def _should_skip(self, choice, trial_id, config, space) -> bool:
"""if config is None or config's result is known or constraints are violated
return True; o.w. return False
@@ -780,28 +809,10 @@ class BlendSearch(Searcher):
return True
config_signature = self._ls.config_signature(config, space)
exists = config_signature in self._result
# check constraints
if not exists and self._config_constraints:
for constraint in self._config_constraints:
func, sign, threshold = constraint
value = func(config)
if (
sign == "<="
and value > threshold
or sign == ">="
and value < threshold
or sign == ">"
and value <= threshold
or sign == "<"
and value > threshold
):
self._result[config_signature] = {
self._metric: np.inf * self._ls.metric_op,
"time_total_s": 1,
}
exists = True
break
if exists: # suggested before
if not exists:
# check constraints
exists = self._violate_config_constriants(config, config_signature)
if exists: # suggested before (including violate constraints)
if choice >= 0: # not fallback to rs
result = self._result.get(config_signature)
if result: # finished

View File

@@ -52,8 +52,8 @@ class MyRegularizedGreedyForest(SKLearnEstimator):
@classmethod
def size(cls, config):
max_leaves = int(round(config["max_leaf"]))
n_estimators = int(round(config["n_iter"]))
max_leaves = int(round(config.get("max_leaf", 1)))
n_estimators = int(round(config.get("n_iter", 1)))
return (max_leaves * 3 + (max_leaves - 1) * 4 + 1.0) * n_estimators * 8
@classmethod

View File

@@ -12,14 +12,17 @@ def test_config_constraint():
else:
return 0
tune.run(
analysis = tune.run(
evaluate_config_dict,
config={
"x": tune.qloguniform(lower=1, upper=100000, q=1),
"y": tune.qrandint(lower=2, upper=100000, q=2),
},
config_constraints=[(config_constraint, ">", 0.5)],
config_constraints=[(config_constraint, "<", 0.5)],
metric="metric",
mode="max",
num_samples=100,
)
assert analysis.best_config["x"] > analysis.best_config["y"]
assert analysis.trials[0].config["x"] > analysis.trials[0].config["y"]