mirror of
https://github.com/microsoft/autogen.git
synced 2026-04-20 03:02:16 -04:00
check config constraints for the initial config (#685)
* check config constraints for the initial config * default config value
This commit is contained in:
@@ -752,6 +752,9 @@ class BlendSearch(Searcher):
|
||||
if result: # tried before
|
||||
return None
|
||||
elif result is None: # not tried before
|
||||
if self._violate_config_constriants(config, config_signature):
|
||||
# violate config constraints
|
||||
return None
|
||||
self._result[config_signature] = {}
|
||||
else: # running but no result yet
|
||||
return None
|
||||
@@ -772,6 +775,32 @@ class BlendSearch(Searcher):
|
||||
config[INCUMBENT_RESULT] = choice_thread.best_result
|
||||
return config
|
||||
|
||||
def _violate_config_constriants(self, config, config_signature):
|
||||
"""check if config violates config constraints.
|
||||
If so, set the result to worst and return True.
|
||||
"""
|
||||
if not self._config_constraints:
|
||||
return False
|
||||
for constraint in self._config_constraints:
|
||||
func, sign, threshold = constraint
|
||||
value = func(config)
|
||||
if (
|
||||
sign == "<="
|
||||
and value > threshold
|
||||
or sign == ">="
|
||||
and value < threshold
|
||||
or sign == ">"
|
||||
and value <= threshold
|
||||
or sign == "<"
|
||||
and value > threshold
|
||||
):
|
||||
self._result[config_signature] = {
|
||||
self._metric: np.inf * self._ls.metric_op,
|
||||
"time_total_s": 1,
|
||||
}
|
||||
return True
|
||||
return False
|
||||
|
||||
def _should_skip(self, choice, trial_id, config, space) -> bool:
|
||||
"""if config is None or config's result is known or constraints are violated
|
||||
return True; o.w. return False
|
||||
@@ -780,28 +809,10 @@ class BlendSearch(Searcher):
|
||||
return True
|
||||
config_signature = self._ls.config_signature(config, space)
|
||||
exists = config_signature in self._result
|
||||
# check constraints
|
||||
if not exists and self._config_constraints:
|
||||
for constraint in self._config_constraints:
|
||||
func, sign, threshold = constraint
|
||||
value = func(config)
|
||||
if (
|
||||
sign == "<="
|
||||
and value > threshold
|
||||
or sign == ">="
|
||||
and value < threshold
|
||||
or sign == ">"
|
||||
and value <= threshold
|
||||
or sign == "<"
|
||||
and value > threshold
|
||||
):
|
||||
self._result[config_signature] = {
|
||||
self._metric: np.inf * self._ls.metric_op,
|
||||
"time_total_s": 1,
|
||||
}
|
||||
exists = True
|
||||
break
|
||||
if exists: # suggested before
|
||||
if not exists:
|
||||
# check constraints
|
||||
exists = self._violate_config_constriants(config, config_signature)
|
||||
if exists: # suggested before (including violate constraints)
|
||||
if choice >= 0: # not fallback to rs
|
||||
result = self._result.get(config_signature)
|
||||
if result: # finished
|
||||
|
||||
@@ -52,8 +52,8 @@ class MyRegularizedGreedyForest(SKLearnEstimator):
|
||||
|
||||
@classmethod
|
||||
def size(cls, config):
|
||||
max_leaves = int(round(config["max_leaf"]))
|
||||
n_estimators = int(round(config["n_iter"]))
|
||||
max_leaves = int(round(config.get("max_leaf", 1)))
|
||||
n_estimators = int(round(config.get("n_iter", 1)))
|
||||
return (max_leaves * 3 + (max_leaves - 1) * 4 + 1.0) * n_estimators * 8
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -12,14 +12,17 @@ def test_config_constraint():
|
||||
else:
|
||||
return 0
|
||||
|
||||
tune.run(
|
||||
analysis = tune.run(
|
||||
evaluate_config_dict,
|
||||
config={
|
||||
"x": tune.qloguniform(lower=1, upper=100000, q=1),
|
||||
"y": tune.qrandint(lower=2, upper=100000, q=2),
|
||||
},
|
||||
config_constraints=[(config_constraint, ">", 0.5)],
|
||||
config_constraints=[(config_constraint, "<", 0.5)],
|
||||
metric="metric",
|
||||
mode="max",
|
||||
num_samples=100,
|
||||
)
|
||||
|
||||
assert analysis.best_config["x"] > analysis.best_config["y"]
|
||||
assert analysis.trials[0].config["x"] > analysis.trials[0].config["y"]
|
||||
|
||||
Reference in New Issue
Block a user